In [None]:
from depth_pro.utils import load_rgb
from depth_pro import depth_pro
import matplotlib.pyplot as plt
from matplotlib import cm
from PIL import Image
import numpy as np
import torch

In [None]:
# Load model and preprocessing transform.
# Prefer Apple-silicon MPS, then CUDA, and fall back to CPU so the notebook
# also runs on machines without an Apple GPU (hard-coding "mps" crashes there).
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model, transform = depth_pro.create_model_and_transforms(device=device)
model.eval();  # inference mode; trailing ';' suppresses the module repr

In [None]:
image_path = "../data/patryk_reka.jpg"  # input photo, relative to the notebook's directory

In [None]:
# Read the photo with Depth Pro's loader (second return value is unused) and
# keep the focal length in pixels, which may be None when it is unknown.
rgb_image, _, f_px = load_rgb(image_path)
# Convert the loaded image into the model's input format.
image = transform(rgb_image)

In [None]:
# Run inference, passing the focal length obtained by load_rgb.
prediction = model.infer(image, f_px=f_px)
# depth: per-pixel depth in metres; focallength_px: focal length in pixels.
depth, focallength_px = prediction["depth"], prediction["focallength_px"]

In [None]:
# Summarize the prediction instead of echoing raw tensor values: non-scalar
# tensors are shown as (shape, dtype); scalars and plain values as-is.
{
    key: (tuple(value.shape), value.dtype) if torch.is_tensor(value) and value.ndim > 0 else value
    for key, value in prediction.items()
}

In [None]:
# Focal length in pixels as a plain number. float() works whether the value
# comes back as a 0-d tensor (possibly on the accelerator) or as a Python/NumPy
# number; `.item()` would fail on a plain Python float.
float(focallength_px)

In [None]:
depth.size()  # same torch.Size as depth.shape; presumably (height, width) — confirm

In [None]:
# Bring the depth map back to host memory and view it as a NumPy array.
depth_cpu = depth.cpu()
depth_np = depth_cpu.numpy()

depth_min, depth_max = depth_np.min(), depth_np.max()

# Rescale to [0, 1] with the nearest depth -> 1 and the farthest -> 0, so close
# surfaces render bright. A constant map (zero range) becomes all zeros rather
# than dividing by zero.
depth_normalized = (
    (depth_max - depth_np) / (depth_max - depth_min)
    if depth_max - depth_min > 0
    else np.zeros_like(depth_np)
)

# 8-bit grayscale rendering, written to the current working directory.
depth_scaled = (depth_normalized * 255).astype(np.uint8)
depth_image = Image.fromarray(depth_scaled)
depth_image.save('depth_grayscale.png')

# Viridis rendering: the colormap yields RGBA floats in [0, 1]; keep RGB only.
colormap = plt.get_cmap('viridis')
depth_colored = (colormap(depth_normalized)[:, :, :3] * 255).astype(np.uint8)
depth_image_colored = Image.fromarray(depth_colored)
depth_image_colored.save('depth_colored.png')

In [None]:
# Show the input photo next to both depth renderings.
# Reload the photo with load_rgb, the same loader that produced the model
# input. It is then displayed in the same orientation as the depth map, rather
# than relying on a hard-coded rotate(-90) that only matches photos taken in
# one particular camera orientation.
original_image = load_rgb(image_path)[0]

# The normalization cell maps near -> 1 and far -> 0, so the nearest depth is
# drawn white (gray) / yellow (viridis). The colorbars therefore use the
# *reversed* colormaps over [depth_min, depth_max], so each color is labelled
# with the depth in metres it actually represents.
depth_norm = plt.Normalize(vmin=depth_min, vmax=depth_max)

fig, axs = plt.subplots(1, 3, figsize=(24, 8))

axs[0].imshow(original_image)
axs[0].set_title('Original Image')
axs[0].axis('off')

# Pin the gray scale to the full uint8 range so 0/255 always mean far/near.
axs[1].imshow(depth_scaled, cmap='gray', vmin=0, vmax=255)
axs[1].set_title('Grayscale Depth Map')
axs[1].axis('off')
sm_gray = cm.ScalarMappable(cmap=plt.get_cmap('gray').reversed(), norm=depth_norm)
sm_gray.set_array([])
fig.colorbar(sm_gray, ax=axs[1], fraction=0.046, pad=0.04, label='Depth (m)')

axs[2].imshow(depth_colored)
axs[2].set_title('Colored Depth Map (Viridis)')
axs[2].axis('off')
sm = cm.ScalarMappable(cmap=colormap.reversed(), norm=depth_norm)
sm.set_array([])
fig.colorbar(sm, ax=axs[2], fraction=0.046, pad=0.04, label='Depth (m)')

plt.tight_layout()
plt.show()

