In [3]:
import torch
import random
import torch.nn as nn
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import AutoImageProcessor, DPTForDepthEstimation
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import pandas as pd

In [4]:
# Seed every RNG the notebook imports (torch, numpy, stdlib random) so that
# results are repeatable after Restart Kernel -> Run All, not only the
# random.randint sample selection further down.
SEED = 98
torch.manual_seed(SEED)
np.random.seed(SEED)
# Kept last: random.seed returns None, so the cell produces no output.
random.seed(SEED)

In [5]:
# Spatial size handed to torch.nn.functional.interpolate(size=...) when model
# outputs are resized before scoring and plotting.
# NOTE(review): 255 looks like it may be a typo for 256 — confirm against the
# shape of the label arrays stored in the .npz files.
IMAGE_SIZE = (256, 255)

# First CUDA device when one is available, otherwise the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [6]:
class MaskedMAE(nn.Module):
    """Mean absolute error computed only where the target is usable.

    Args:
        valid_mask: when truthy, only positions with ``target > 0`` (and, if
            ``max_depth`` is given, ``target <= max_depth``) are scored.
        max_depth: optional upper bound on target values kept by the mask.
    """

    def __init__(self, valid_mask=True, max_depth=None):
        super().__init__()
        self.max_depth = max_depth
        self.valid_mask = valid_mask

    def mae(self, input, target):
        """Return the mean of ``|input - target|`` over the selected positions.

        ``input`` and ``target`` must be indexable by the same boolean mask,
        i.e. have matching shapes. If the mask selects nothing, the mean of an
        empty tensor is returned (NaN for floating-point inputs).
        """
        if self.valid_mask:
            keep = target > 0
            if self.max_depth is not None:
                # Also drop targets beyond the configured depth limit.
                keep = keep & (target <= self.max_depth)
            input, target = input[keep], target[keep]

        return (input - target).abs().mean()

    def forward(self, depth_pred, depth_gt):
        """Module entry point; delegates to :meth:`mae`."""
        return self.mae(depth_pred, depth_gt)

In [7]:
class MaskedR2Score(nn.Module):
    """Coefficient of determination (R^2) computed only where the target is usable.

    Args:
        valid_mask: when truthy, only positions with ``target > 0`` (and, if
            ``max_depth`` is given, ``target <= max_depth``) are scored.
        max_depth: optional upper bound on target values kept by the mask.
    """

    def __init__(self, valid_mask=True, max_depth=None):
        super(MaskedR2Score, self).__init__()

        self.valid_mask = valid_mask
        self.max_depth = max_depth

    def r2(self, input, target):
        """Return ``1 - SS_res / SS_tot`` over the selected positions.

        Edge cases:
            * Selected targets all equal (``SS_tot == 0``): the ratio is
              undefined, so a finite value is returned instead of NaN / -inf —
              1.0 when the predictions match exactly, 0.0 otherwise. This
              mirrors the default convention of scikit-learn's ``r2_score``.
            * Nothing selected: NaN is returned, as there is nothing to score.
        """
        if self.valid_mask:
            valid_mask = target > 0
            if self.max_depth is not None:
                valid_mask = torch.logical_and(valid_mask, target <= self.max_depth)
            input = input[valid_mask]
            target = target[valid_mask]

        mean_target = torch.mean(target)
        ss_total = torch.sum((target - mean_target) ** 2)
        ss_residual = torch.sum((input - target) ** 2)

        if target.numel() > 0 and ss_total == 0:
            # Zero target variance: avoid dividing by zero, which would yield
            # NaN or -inf and poison any running average built from this score.
            return (ss_residual == 0).to(ss_residual.dtype)

        return 1 - (ss_residual / ss_total)

    def forward(self, depth_pred, depth_gt):
        """Module entry point; delegates to :meth:`r2`."""
        metric_r2 = self.r2(depth_pred, depth_gt)
        return metric_r2

In [8]:
class CustomNPZDataset(Dataset):
    """Dataset over a directory of ``.npz`` files, each holding arrays ``X`` and ``y``.

    Args:
        path: directory searched (non-recursively) for ``*.npz`` files.
        transform: optional callable applied to both ``X`` and ``y``.
            NOTE(review): it is called separately on each tensor, so a
            randomised transform would draw different parameters for the
            image and its label — confirm only deterministic transforms
            are passed.
        image_transforms: optional callable applied to ``X`` only, after
            ``transform``.
    """

    def __init__(self, path, transform=None, image_transforms=None):
        self.path = path
        # Sorted so that index -> file is stable across runs and machines;
        # glob order depends on the filesystem, which would otherwise defeat
        # seeded index selection.
        self.files = sorted(Path(path).glob('*.npz'))
        self.transform = transform
        self.image_transforms = image_transforms

    def __len__(self):
        """Number of ``.npz`` files found under ``path``."""
        return len(self.files)

    def __getitem__(self, item):
        """Load sample ``item`` and return ``(X, y)`` as tensors.

        ``y`` gains a leading dimension of size 1 via ``unsqueeze(0)``.
        """
        # Read both arrays while the archive is open; the context manager
        # closes the file handle afterwards.
        with np.load(str(self.files[item])) as data:
            X_numpy = data['X']
            y_numpy = data['y']
        X_torch = torch.from_numpy(X_numpy)
        y_torch = torch.from_numpy(y_numpy).unsqueeze(0)
        if self.transform is not None:
            X_torch = self.transform(X_torch)
            y_torch = self.transform(y_torch)
        if self.image_transforms is not None:
            # Image-only transform: must use image_transforms, not transform
            # (which may be None and has already been applied above).
            X_torch = self.image_transforms(X_torch)
        return X_torch, y_torch

In [None]:
# Display name -> checkpoint identifier handed to
# DPTForDepthEstimation.from_pretrained in the evaluation and plotting cells.
# NOTE(review): "model" appears to be a local directory (relative to the
# notebook) holding the from-scratch checkpoint — confirm it exists.
models = dict(
    KITTI="facebook/dpt-dinov2-small-kitti",
    NYUd="facebook/dpt-dinov2-small-nyu",
    Scratch="model",
)

In [None]:
# Checkpoint name whose preprocessing configuration is loaded below.
CONFIG_NAME = "facebook/dinov2-small"
# Shared image processor: run_model uses it to turn PIL images into model
# inputs for every checkpoint that is evaluated.
image_processor = AutoImageProcessor.from_pretrained(CONFIG_NAME)

In [None]:
# Directory of .npz samples; CustomNPZDataset globs '*.npz' inside it.
TEST_PATH = "data/thumbnails/test"
# No transform / image_transforms are passed, so samples come back untransformed.
# NOTE(review): the variables are named "validation" but the path points at
# the test split — confirm which split is intended.
validation_dataset = CustomNPZDataset(path=TEST_PATH)
# shuffle is left at its default, so batches follow dataset order.
validation_loader = DataLoader(validation_dataset, batch_size=32)

In [None]:
# Metric modules with default settings: masking on (target > 0), no max_depth cap.
# Both are read as globals by run_model.
mae_fn = MaskedMAE()
r2_fn = MaskedR2Score()

In [None]:
def run_model(data):
    """Score one batch with the module-level ``model``.

    Reads the globals ``model``, ``image_processor``, ``mae_fn``, ``r2_fn``,
    ``DEVICE`` and ``IMAGE_SIZE``. Gradient tracking is not disabled here;
    the caller is expected to wrap the call in ``torch.no_grad()``.

    Args:
        data: an ``(inputs, labels)`` pair as yielded by the data loader.

    Returns:
        ``(mae, r2)`` — the two metric values for this batch.
    """
    batch_inputs, batch_labels = data

    # The processor is fed PIL images, so each sample is converted to a
    # channels-last numpy array first.
    # assumes every sample is a uint8 (C, H, W) CPU tensor — TODO confirm
    pil_images = []
    for sample in batch_inputs:
        pil_images.append(Image.fromarray(sample.numpy().transpose(1, 2, 0)))

    encoded = image_processor(images=pil_images, return_tensors="pt").to(DEVICE)
    predicted_depth = model(**encoded)['predicted_depth']

    # Resize the depth maps to IMAGE_SIZE so they can be compared to the labels.
    # NOTE(review): the processor may resize/crop the image before inference;
    # confirm the resized output stays spatially aligned with the labels.
    resized = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=IMAGE_SIZE,
        mode="bicubic",
        align_corners=False,
    )

    targets = batch_labels.to(DEVICE)

    return mae_fn(resized, targets), r2_fn(resized, targets)

In [None]:
# Evaluate every checkpoint in `models` over the whole loader and collect the
# averaged metrics. The averages are kept as tensors because the summary cell
# that follows calls .item() on them.
model_data = {"names": [], "maes": [], "r2s": []}
for name, checkpoint in models.items():
    # `model` is deliberately a module-level name: run_model reads it as a global.
    model = DPTForDepthEstimation.from_pretrained(checkpoint)
    # Use DEVICE rather than .cuda() so the CPU fallback chosen in the config
    # cell works and the model sits on the same device as its inputs.
    model.to(DEVICE)
    model.eval()
    running_vmae = 0.0
    running_vr2 = 0.0
    with torch.no_grad():
        for vdata in tqdm(validation_loader, desc=name):
            vmae, vr2 = run_model(vdata)
            running_vmae += vmae
            running_vr2 += vr2
    # Mean of per-batch scores: every batch carries equal weight regardless of
    # how many valid pixels it contained.
    # NOTE(review): confirm this (rather than a per-pixel aggregate) is intended.
    avg_vmae = running_vmae / len(validation_loader)
    avg_vr2 = running_vr2 / len(validation_loader)
    model_data["names"].append(name)
    model_data["maes"].append(avg_vmae)
    model_data["r2s"].append(avg_vr2)

In [None]:
# Turn the averaged metric tensors into plain floats rounded to two decimals
# and tabulate one row per model.
maes_clean = (round(score.item(), 2) for score in model_data["maes"])
r2s_clean = (round(score.item(), 2) for score in model_data["r2s"])
clean_data = {
    "Model": model_data["names"],
    "MAE": list(maes_clean),
    "R2": list(r2s_clean),
}
df = pd.DataFrame(clean_data)
df

In [None]:
num_images = 4

# Loop-invariant setup, done once instead of once per image: the processor
# and each checkpoint are loaded a single time and reused for every sample.
# All checkpoints stay resident on DEVICE for the duration of the cell.
image_processor = AutoImageProcessor.from_pretrained(CONFIG_NAME)
loaded_models = {}
for name, checkpoint in models.items():
    loaded = DPTForDepthEstimation.from_pretrained(checkpoint)
    # DEVICE (not .cuda()) so the CPU fallback works and matches the inputs.
    loaded.to(DEVICE)
    loaded.eval()
    loaded_models[name] = loaded

# One figure per randomly drawn sample; one row per model with the columns
# input image | prediction | ground truth.
# The names X, y, predictions, name and image are left bound to their final
# values after the loops; the following cell reads them.
for image in range(num_images):
    X, y = validation_dataset[random.randint(0, len(validation_dataset)-1)]
    # assumes X is a uint8 (C, H, W) tensor — TODO confirm
    images = [Image.fromarray(X.numpy().transpose(1, 2, 0))]
    inputs = image_processor(images=images, return_tensors="pt")
    inputs = inputs.to(DEVICE)
    # squeeze=False keeps `axes` two-dimensional even when only one model is
    # configured, so axes[i][j] indexing always works.
    fig, axes = plt.subplots(len(models), 3, figsize=(12, len(models) * 4), squeeze=False)
    for i, name in enumerate(models):
        model = loaded_models[name]
        with torch.no_grad():
            outputs = model(**inputs)
        predicted_depth = outputs['predicted_depth']
        # Resize the model output to IMAGE_SIZE for display.
        predictions = torch.nn.functional.interpolate(
            predicted_depth.unsqueeze(1),
            size=IMAGE_SIZE,
            mode="bicubic",
            align_corners=False,
        )
        axes[i][0].imshow(X.numpy().transpose(1, 2, 0))
        axes[i][0].set_title("Satelite Image")
        axes[i][1].imshow(predictions.squeeze(0).squeeze(0).cpu().numpy(), cmap='viridis')
        axes[i][1].set_title(f"{name} Prediction")
        axes[i][2].imshow(y.squeeze(0).numpy(), cmap='viridis')
        axes[i][2].set_title("LiDAR Ground Truth")

    # Figures are intentionally left open after saving.
    # NOTE(review): closing them here may stop the notebook from rendering
    # them inline; close in a later cell if memory becomes a concern.
    fig.savefig(f"image-{image}.png", dpi=fig.dpi, bbox_inches='tight')

In [None]:
# Single-row figure: input image | prediction | ground truth.
# X, y, predictions, name and image are not defined in this cell; they are
# whatever the preceding plotting loop last assigned, so this cell must be
# run after it. The file name hard-codes "Scratch" while the title uses `name`.
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
satellite_ax, prediction_ax, truth_ax = axes

satellite_ax.imshow(X.numpy().transpose(1, 2, 0))
satellite_ax.set_title("Satelite Image")

depth_map = predictions.squeeze(0).squeeze(0).cpu().numpy()
prediction_ax.imshow(depth_map, cmap='viridis')
prediction_ax.set_title(f"{name} Prediction")

truth_ax.imshow(y.squeeze(0).numpy(), cmap='viridis')
truth_ax.set_title("LiDAR Ground Truth")

fig.savefig(f"image-{image}_Scratch.png", dpi=fig.dpi, bbox_inches='tight')