<a href="https://colab.research.google.com/github/skj092/monocular-depth-estimation/blob/main/mde_pt_keras_doc.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
# !pip install -qq pytorch_msssim

In [13]:
import sys
sys.path.append('src/mde/')
import pandas as pd
import numpy as np
import cv2
import os
import torch
from sklearn.model_selection import train_test_split
import wandb
from dotenv import load_dotenv
from loss_fns import calculate_loss, abs_rel, log10_mae, log10_rmse, threshold_accuracy, delta1, delta2, delta3, CombinedDepthLoss
from model2 import DepthEstimationModel
# from fastai.vision.all import *
import keras
from torch.utils.data import DataLoader
import torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as plt
from engine import train_epoch, validate_epoch

In [None]:
# Data Download
# annotation_folder = "/dataset/"
# if not os.path.exists(os.path.abspath(".") + annotation_folder):
#     annotation_zip = keras.utils.get_file(
#         "val.tar.gz",
#         cache_subdir=os.path.abspath("."),
#         origin="http://diode-dataset.s3.amazonaws.com/val.tar.gz",
#         extract=True,
#     )

# wandb.login(key=os.getenv("WANDB_API_KEY"))
# wandb.init(project="DepthEstimation", name="FastAI_Training",)

# load_dotenv()

In [None]:
# data loading and preprocessing
# Every sample is a triplet of files that share one path stem:
#   <stem>.png, <stem>_depth.npy, <stem>_depth_mask.npy
train_path = "val_extracted/val/indoors"
if not os.path.isdir(train_path):
    # Fail here with a clear message instead of much later inside train_test_split.
    raise FileNotFoundError(f"Dataset directory not found: {train_path!r}")

filelist = sorted(
    os.path.join(root, file)
    for root, dirs, files in os.walk(train_path)
    for file in files
)

images = [x for x in filelist if x.endswith(".png")]
depths = [x for x in filelist if x.endswith("_depth.npy")]
masks = [x for x in filelist if x.endswith("_depth_mask.npy")]

# Key every file by its full path minus the type suffix so the three kinds are
# matched explicitly; relying on three independently sorted lists lining up
# breaks (or silently mis-pairs) as soon as one file of a triplet is missing.
image_by_stem = {x[: -len(".png")]: x for x in images}
depth_by_stem = {x[: -len("_depth.npy")]: x for x in depths}
mask_by_stem = {x[: -len("_depth_mask.npy")]: x for x in masks}

# Keep only complete triplets, ordered by image path for a deterministic frame.
valid_stems = sorted(
    image_by_stem.keys() & depth_by_stem.keys() & mask_by_stem.keys(),
    key=image_by_stem.__getitem__,
)
if not valid_stems:
    raise RuntimeError(f"No complete image/depth/mask triplets found under {train_path!r}")

valid_images = [image_by_stem[s] for s in valid_stems]
valid_depths = [depth_by_stem[s] for s in valid_stems]
valid_masks = [mask_by_stem[s] for s in valid_stems]

# Build DataFrame: one row per sample, columns hold file paths.
data = {
    "image": valid_images,
    "depth": valid_depths,
    "mask": valid_masks,
}
df = pd.DataFrame(data)
# Seeded shuffle followed by a seeded 80/20 split keeps the partition reproducible.
df = df.sample(frac=1, random_state=42).reset_index(drop=True)
print(f"Total train samples: {len(df)}")
train_df, valid_df = train_test_split(df, test_size=0.2, random_state=42)
print(f"Training samples: {len(train_df)}, Validation samples: {len(valid_df)}")

In [None]:
# Dataset and DataLoader setup
def open_image(fn, size=(256, 256)):
    """Read an image file and return it as a float RGB tensor.

    Args:
        fn: Path of the image file to read.
        size: Target ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        ``torch.float32`` tensor of shape ``(3, height, width)`` scaled by 1/255.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``fn``.
    """
    img = cv2.imread(fn)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising,
        # which would otherwise surface as an opaque error inside cvtColor.
        raise FileNotFoundError(f"Could not read image: {fn!r}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads channels as BGR
    img = cv2.resize(img, size)
    # HWC -> CHW, then scale to [0, 1].
    return torch.tensor(img, dtype=torch.float32).permute(2, 0, 1) / 255.0

def load_depth_masked(row, size=(256, 256)):
    """Load a depth map and its validity mask and build a normalised log-depth target.

    Args:
        row: Mapping with ``'depth'`` and ``'mask'`` entries holding ``.npy`` paths.
            The depth array is squeezed; the mask is assumed to have the same
            shape as the squeezed depth (TODO confirm against the dataset).
        size: Target ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        ``torch.float32`` tensor of shape ``(1, height, width)`` with values in
        ``[0, 1]``, quantised to 1/255 steps. Pixels outside the mask are 0.
    """
    depth = np.load(row['depth']).squeeze()
    mask = np.load(row['mask']) > 0
    if not mask.any():
        # No valid pixel: the percentile below would be taken over an empty array.
        return torch.zeros((1, size[1], size[0]), dtype=torch.float32)

    epsilon = 1e-6
    # Cap at the 99th percentile of the valid depths (and at 300) to drop outliers.
    max_depth = min(300, np.percentile(depth[mask], 99))
    depth = np.clip(depth, epsilon, max_depth)
    depth_log = np.zeros_like(depth)
    depth_log[mask] = np.log(depth[mask])
    depth_log = np.clip(depth_log, np.log(epsilon), np.log(max_depth))
    depth_log = cv2.resize(depth_log, size)
    with np.errstate(invalid='ignore', divide='ignore'):
        # log(max_depth) is 0 when max_depth == 1; the resulting nan/inf are zeroed below.
        depth_norm = depth_log / np.log(max_depth)
    depth_norm = np.nan_to_num(depth_norm, nan=0.0, posinf=0.0, neginf=0.0)
    # Valid depths below 1 have a negative log (and max_depth below 1 flips the
    # sign of the divisor), so depth_norm can leave [0, 1]. Casting such values
    # to uint8 wraps around and corrupts the target, hence the explicit clip.
    depth_norm = np.clip(depth_norm, 0.0, 1.0)
    depth_uint8 = (depth_norm * 255).astype(np.uint8)
    # Add channel dimension and normalize
    return torch.tensor(depth_uint8, dtype=torch.float32).unsqueeze(0) / 255.0

def get_x(row):
    """Return the model input for ``row``: the image loaded from ``row['image']``."""
    image_path = row['image']
    return open_image(image_path)
def get_y(row):
    """Return the training target for ``row`` by delegating to ``load_depth_masked``."""
    target = load_depth_masked(row)
    return target

class DepthDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(input, target)`` pairs from a table of file paths.

    Each row of ``df`` is handed to ``get_x`` and ``get_y`` to produce the pair.
    """

    def __init__(self, df):
        # Rows are looked up by position, so the frame's index labels do not matter.
        self.df = df

    def __len__(self):
        """Number of samples, i.e. rows in the backing frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Build the ``(x, y)`` pair for the row at position ``idx``."""
        record = self.df.iloc[idx]
        return get_x(record), get_y(record)

# Wrap each split in a dataset, then batch it; only the training loader shuffles.
train_ds, valid_ds = DepthDataset(train_df), DepthDataset(valid_df)
train_dl = DataLoader(dataset=train_ds, batch_size=16, shuffle=True, num_workers=4)
valid_dl = DataLoader(dataset=valid_ds, batch_size=16, shuffle=False, num_workers=4)

In [None]:
# visualize a batch of data
def visualize_batch(dl):
    """Plot the first batch of ``dl``: inputs on the top row, depth maps below.

    Args:
        dl: Iterable of ``(x, y)`` batches. ``x`` is permuted from
            ``[B, C, H, W]`` to ``[B, H, W, C]`` for display; ``y`` is expected
            to carry a singleton channel axis at dim 1, which is dropped.

    Returns:
        None. Does nothing when ``dl`` yields no batch.
    """
    first_batch = next(iter(dl), None)  # only one batch is shown
    if first_batch is None:
        return
    x, y = first_batch
    x = x.permute(0, 2, 3, 1).numpy()  # Convert to [B, H, W, C] format
    # Drop only the channel axis; a bare squeeze() would also drop the batch
    # axis when the batch holds a single sample.
    y = y.squeeze(1).numpy()
    # squeeze=False keeps `axes` 2-D even with a single column.
    fig, axes = plt.subplots(nrows=2, ncols=len(x), figsize=(15, 5), squeeze=False)
    for i in range(len(x)):
        axes[0, i].imshow(x[i])
        axes[0, i].set_title('Input Image')
        axes[0, i].axis('off')
        axes[1, i].imshow(y[i], cmap='gray')
        axes[1, i].set_title('Depth Map')
        axes[1, i].axis('off')
    plt.tight_layout()
    plt.show()

# Sanity check: plot the first batch from the training loader (inputs on top, depth maps below).
visualize_batch(train_dl)

In [17]:
# Loss with its three component weights (ssim / l1 / edge) all set to 1.0.
loss_func = CombinedDepthLoss(
    ssim_loss_weight=1.0,
    l1_loss_weight=1.0,
    edge_loss_weight=1.0,
)

# Use the GPU when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model = DepthEstimationModel(in_channels=3)
model = model.to(device)

# Optimizer and scheduler
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
# scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-4, steps_per_epoch=len(train_dl), epochs=10)

def train_model(model, train_dl, valid_dl, optimizer, loss_func, device, epochs=10):
    """Run ``epochs`` rounds of training followed by validation and return the model.

    Each round calls ``train_epoch`` then ``validate_epoch`` and prints the two
    losses together with whatever metrics object validation returned.
    """
    for epoch_no in range(1, epochs + 1):
        train_loss = train_epoch(model, train_dl, optimizer, loss_func, device)
        validation = validate_epoch(model, valid_dl, loss_func, device)
        valid_loss, metrics = validation

        summary = f"Epoch {epoch_no}/{epochs}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}"
        print(summary)
        print(f"Metrics: {metrics}")

        # Currently disabled: learning-rate schedule and experiment tracking.
        # scheduler.step()
        # wandb.log({
        #     'epoch': epoch_no,
        #     'train_loss': train_loss,
        #     'valid_loss': valid_loss,
        #     **metrics
        # })
    return model

Total train samples: 325
Training samples: 260, Validation samples: 65


In [19]:
# Train for 10 epochs, then persist only the learned weights (the state_dict).
# NOTE(review): the saved output of this cell ends in
# "NameError: name 'abs_rel' is not defined" during the validation pass --
# presumably raised inside engine.validate_epoch; confirm and fix there.
model = train_model(
    model,
    train_dl,
    valid_dl,
    optimizer,
    loss_func,
    device,
    epochs=10,
)
torch.save(model.state_dict(), 'depth_estimation_model.pth')

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:29<00:00,  1.72s/it]
  0%|                                                                                                                                                                   | 0/5 [00:02<?, ?it/s]


NameError: name 'abs_rel' is not defined

In [None]:
import matplotlib.pyplot as plt

def visualize_predictions(model, valid_dl, device, max_batches=5):
    """Plot predicted vs. ground-truth depth for the first sample of each batch.

    Args:
        model: Callable network; switched to eval mode and run without gradients.
        valid_dl: Iterable of ``(x, y)`` batches.
        device: Device the inputs are moved to before the forward pass.
        max_batches: Number of batches (one figure each) to show.

    Returns:
        None. The model is left in eval mode.
    """
    model.eval()
    if max_batches <= 0:
        return
    with torch.no_grad():
        for i, (x, y) in enumerate(valid_dl):
            x = x.to(device)
            pred = model(x)
            # Resize the prediction to the target's spatial size so the two
            # panels line up. torch.nn.functional is reached through `torch`
            # because no `F` alias is imported in this notebook.
            pred = torch.nn.functional.interpolate(
                pred, size=tuple(y.shape[-2:]), mode='bilinear', align_corners=False
            )

            # Get only the first image from the batch
            pred_img = pred[0].cpu().squeeze().numpy()
            y_img = y[0].cpu().squeeze().numpy()

            # Convert to uint8 for visualization; clip first because values
            # outside [0, 1] would wrap around in the uint8 cast.
            pred_vis = (np.clip(pred_img, 0.0, 1.0) * 255).astype(np.uint8)
            y_vis = (np.clip(y_img, 0.0, 1.0) * 255).astype(np.uint8)

            # Plot side-by-side on a shared 0..255 grey scale so brightness is
            # comparable between prediction and ground truth.
            fig, axs = plt.subplots(1, 2, figsize=(10, 4))
            axs[0].imshow(pred_vis, cmap='gray', vmin=0, vmax=255)
            axs[0].set_title(f'Predicted Depth {i}')
            axs[0].axis('off')

            axs[1].imshow(y_vis, cmap='gray', vmin=0, vmax=255)
            axs[1].set_title(f'Ground Truth Depth {i}')
            axs[1].axis('off')

            plt.tight_layout()
            plt.show()

            if i + 1 >= max_batches:  # stop before fetching another batch
                break


In [None]:
# Compare predicted and ground-truth depth for the first image of the first few validation batches.
visualize_predictions(model, valid_dl, device)