In [3]:
import argparse
import cv2
import numpy as np
import os
import torch
import torch.nn.functional as F
from torchvision.transforms import Compose
from tqdm import tqdm

import decord

from depth_anything.dpt import DepthAnything
from depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet

transform = Compose([
        Resize(
            width=518,
            height=518,
            resize_target=False,
            keep_aspect_ratio=True,
            ensure_multiple_of=14,
            resize_method='lower_bound',
            image_interpolation_method=cv2.INTER_CUBIC,
        ),
        NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        PrepareForNet(),
    ])

def depth_analysis(raw_image):
    image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0

    h, w = image.shape[:2]

    image = transform({'image': image})['image']
    image = torch.from_numpy(image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        depth = depth_anything(image)

    depth = F.interpolate(depth[None], (h, w), mode='bilinear', align_corners=False)[0, 0]
    depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0

    depth = depth.cpu().numpy().astype(np.uint8)
    depth_color = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO) # 输出是这个图片
    
    return depth_color

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
   
encoder = 'vitb' # can also be 'vitb' or 'vitl'
depth_anything = DepthAnything.from_pretrained('LiheYoung/depth_anything_{:}14'.format(encoder))

total_params = sum(param.numel() for param in depth_anything.parameters())
print('Total parameters: {:.2f}M'.format(total_params / 1e6))


depth_anything.eval()

# read video
video_path = './test.mp4' # <- input video
video = decord.VideoReader(video_path, height=518, width=518)

fps = video.get_avg_fps() # 每秒显示多少帧
nframes = len(video)  # 总的帧

num = nframes # 根据需要选择抽多少,此处假设全抽
frameDuration = nframes//num
indexes = range(0,nframes,int(frameDuration))
indexes = [i+int(frameDuration) for i in indexes if (i+int(frameDuration))<nframes] # 抽取视频帧

video_frames_raw = video.get_batch(indexes).asnumpy()
depth_colors = [] # depth pics

for i in tqdm(video_frames_raw):
    img = cv2.cvtColor(np.array(i),cv2.COLOR_BGR2RGB) # convert to cv2 style

    depth_color = depth_analysis(img)
    depth_colors.append(depth_color)

# save video
video_writer = cv2.VideoWriter('test.mp4', cv2.VideoWriter_fourcc('m', 'p', '4', 'v'), fps, (518, 518), True)
for pic in tqdm(depth_colors):
    video_writer.write(pic) # -> output video
    
video_writer.release()

Loading weights from local directory
Total parameters: 97.47M


100%|██████████| 2326/2326 [02:07<00:00, 18.21it/s]
100%|██████████| 2326/2326 [00:04<00:00, 516.37it/s]
