<a href="https://colab.research.google.com/github/steinhaug/stable-diffusion/blob/main/tool/batch_depth_frames.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

1. Extract all frames from the video file  
2. Crop the frames to 768x512  
3. Create depth maps from the cropped frames

In [None]:
# Install the system and Python dependencies used by this notebook.
# clear_output() is imported here and is also relied on by later cells to tidy their output.
from IPython.display import display, clear_output
# Refresh the apt package index quietly.
!apt -y update -qq
# NOTE(review): aria2 is not used by any cell shown here — confirm it is still needed.
!apt -y install -qq aria2
# Presumably needed by the MiDaS DPT models loaded via torch.hub in step 2.1 — confirm; consider pinning a version.
!pip install timm
clear_output();print('Done!')

In [22]:
#@title Notebook functions

import os
_B, _A = True, False


def return__isValidDir(directory):
    """Report whether `directory` refers to an existing directory.

    Returns the module flag `_B` (True) when it does, otherwise `_A` (False).
    """
    return _B if os.path.isdir(directory) else _A

In [26]:
import os

# Source video to process.
video_file = "/content/video1.mp4"

# Frames are written to <frames_root>/<video base name>. Deriving the folder from
# the video's name means switching `video_file` does not silently overwrite or
# mix with a previous video's frames.
frames_root = "/content/frames"
frames_directory = os.path.join(
    frames_root, os.path.splitext(os.path.basename(video_file))[0]
)

## 1.0 Extract and crop frames

In [None]:
#@title . 1.1 Extract frames
if not return__isValidDir(frames_directory):
    os.makedirs(frames_directory)

!ffmpeg -i {video_file} -vf "scale=910:512" {frames_directory}/c01_%04d.png

clear_output();print(f'Frames extracted into {frames_directory}')

In [None]:
#@title . 1.2 Crop frames

from PIL import Image
import os
import sys

def process_images(input_directory, output_directory, target_width=768, target_height=512):
    """Center-crop every image in `input_directory` to `target_width` x `target_height`.

    Each cropped image is written to `output_directory` (created if missing)
    under its original file name. Passing the same directory for input and
    output overwrites the originals in place.

    Images that are smaller than the target in a dimension are neither padded
    nor upscaled; that dimension keeps its original size.

    Args:
        input_directory: Folder scanned for .png/.jpg/.jpeg/.gif/.bmp files
            (extension match is case-insensitive).
        output_directory: Destination folder for the cropped images.
        target_width: Width of the crop window in pixels.
        target_height: Height of the crop window in pixels.
    """
    os.makedirs(output_directory, exist_ok=True)

    image_extensions = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')
    # Sorted so frames are processed in a deterministic order.
    image_files = sorted(
        f for f in os.listdir(input_directory) if f.lower().endswith(image_extensions)
    )

    for image_file in image_files:
        input_path = os.path.join(input_directory, image_file)
        output_path = os.path.join(output_directory, image_file)

        with Image.open(input_path) as img:
            original_width, original_height = img.size

            # Center the crop window, clamped to the image bounds.
            left = max(0, (original_width - target_width) // 2)
            top = max(0, (original_height - target_height) // 2)
            right = min(original_width, left + target_width)
            bottom = min(original_height, top + target_height)

            cropped_img = img.crop((left, top, right, bottom))

        # Save only after the source file is closed: input and output may be
        # the same path when cropping in place.
        cropped_img.save(output_path)

# Crop in place: the cropped frames replace the extracted ones.
process_images(frames_directory, frames_directory)

clear_output()
print('Frames cropped into 768x512')

## 2.0 Create depth maps

In [None]:
#@title . 2.1 Initialise MiDaS depth model
import torch
# Imported before first use below (previously imported after clear_output was
# already called, which only worked because cell 1 had imported it).
from IPython.display import display, clear_output

model_type = "DPT_Large"     # MiDaS v3 - Large     (highest accuracy, slowest inference speed)
#model_type = "DPT_Hybrid"   # MiDaS v3 - Hybrid    (medium accuracy, medium inference speed)
#model_type = "MiDaS_small"  # MiDaS v2.1 - Small   (lowest accuracy, highest inference speed)

midas = torch.hub.load("intel-isl/MiDaS", model_type)
clear_output();print('Depth model downloaded!')

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
midas.to(device)
midas.eval()  # inference mode

# Pick the input transform that matches the selected model family.
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
if model_type in ("DPT_Large", "DPT_Hybrid"):
    transform = midas_transforms.dpt_transform
else:
    transform = midas_transforms.small_transform

clear_output();print('Model loaded and ready...')

In [18]:
#@title . 2.2 Process frames
import cv2, os
import numpy as np
from PIL import Image

def process_images(input_directory, output_directory):
    """Write an 8-bit grayscale depth map for every image in `input_directory`.

    Each map is saved to `output_directory` (created if missing) under the
    same file name as its source frame. Uses the notebook globals `midas`,
    `transform` and `device` prepared in step 2.1.

    The raw model prediction is not guaranteed to lie in 0..255, so casting it
    straight to uint8 wraps or is undefined for out-of-range values. Each map is
    therefore min-max normalised to 0..255 before saving.
    NOTE(review): normalisation is per frame, so brightness may vary between
    frames of a video sequence.

    Note: this redefines the cropping `process_images` from step 1.2.

    Args:
        input_directory: Folder scanned for .png/.jpg/.jpeg/.gif/.bmp files.
        output_directory: Destination folder for the depth maps.

    Raises:
        ValueError: If OpenCV cannot read one of the images.
    """
    os.makedirs(output_directory, exist_ok=True)

    image_extensions = ('.png', '.jpg', '.jpeg', '.gif', '.bmp')
    # Sorted so frames are processed in a deterministic order.
    image_files = sorted(
        f for f in os.listdir(input_directory) if f.lower().endswith(image_extensions)
    )

    for image_file in image_files:
        input_path = os.path.join(input_directory, image_file)
        output_path = os.path.join(output_directory, image_file)

        img = cv2.imread(input_path)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise ValueError(f"Could not read image: {input_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        input_batch = transform(img).to(device)
        with torch.no_grad():
            prediction = midas(input_batch)
            # Resize the prediction back to the source frame's height x width.
            prediction = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=img.shape[:2],
                mode="bicubic",
                align_corners=False,
            ).squeeze()
        depth = prediction.cpu().numpy()

        depth_min = float(depth.min())
        depth_max = float(depth.max())
        if depth_max > depth_min:
            depth_scaled = (depth - depth_min) / (depth_max - depth_min) * 255.0
        else:
            # A constant prediction has no range to stretch; write a black map.
            depth_scaled = np.zeros_like(depth)
        depth_u8 = np.clip(np.round(depth_scaled), 0, 255).astype(np.uint8)

        Image.fromarray(depth_u8).save(output_path)

# Depth maps go to a sibling folder named after the frames folder; kept in a
# variable so later cells can locate them.
depth_directory = f"{frames_directory}_depth"
process_images(frames_directory, depth_directory)
print(f'Depth maps written to {depth_directory}')