In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path
from glob import glob
from natsort import natsorted

import numpy as np
from PIL import Image
import cv2

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import MNIST
import torchvision
import pytorch_lightning as pl



In [9]:
class DepthEstimator(pl.LightningModule):
    def __init__(self, **kwargs):
        super().__init__()
        self.__dict__.update(kwargs)

        # Define data paths
        self.vkitti_dir = self.vkitti_dir
        self.rgb_dirname = "rgb"
        self.depth_dirname = "depth"

        self.batch_size = self.__dict__.get("batch_size", 1)
        self.learning_rate = self.__dict__.get("learning_rate", 1e-5)

        self.stem = torchvision.models.segmentation.fcn_resnet50(progress=True)
        self.stem.classifier[4] = torch.nn.Conv2d(
            512, 1, kernel_size=(1, 1), stride=(1, 1)
        )  # Replace last layer with 1-channel conv

    def forward(self, x):
        return self.stem(x)["out"]

    def configure_optimizers(self):
        return [torch.optim.Adam(self.parameters(), lr=self.learning_rate)]

    def training_step(self, batch, batch_nb):
        rgb = batch["rgb"]
        depth = batch["depth"]
        pred = self.forward(rgb)
        loss = F.mse_loss(pred, depth)
        return {"loss": loss}

    def validation_step(self, batch, batch_nb):
        rgb = batch["rgb"]
        depth = batch["depth"]
        pred = self.forward(rgb)
        loss = F.mse_loss(pred, depth)
        return {"val_loss": loss}

    def validation_end(self, outputs):
        # avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        # tensorboard_logs = {"val_loss": avg_loss}
        # return {"avg_loss": avg_loss, "log": tensorboard_logs}
        return {}

In [14]:
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, Lambda
# Define transforms
transforms = {
    "rgb": Compose([
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
    "depth": Compose([
        ToTensor(),
    ])
}

In [15]:
# Data
VKITTI_DIR = "../data/vkitti"
BATCH_SIZE = 2

train_dataset = VkittiImageDataSet(VKITTI_DIR, ("rgb", "depth"), transforms=transforms)
train_dataset[0]
# val_dataset = VkittiImageDataSet(VKITTI_DIR, ("rgb", "depth"), transforms=transforms)

# train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
# val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

Subset 'rgb' contains 42520 files
Subset 'depth' contains 42520 files
Found a total of 42520 valid files.


{'rgb': tensor([[[ 1.2728,  1.2728,  1.2728,  ..., -1.4158, -1.3302, -1.0219],
          [ 1.2728,  1.2728,  1.2728,  ..., -1.4500, -1.3302, -1.1247],
          [ 1.2899,  1.2899,  1.2899,  ..., -1.4329, -1.3987, -1.3130],
          ...,
          [ 1.9407,  1.9407,  1.9407,  ..., -1.9980, -1.9980, -1.9980],
          [ 1.9235,  1.9235,  1.9407,  ..., -1.9980, -1.9980, -1.9980],
          [ 1.9235,  1.9235,  1.9407,  ..., -1.9980, -1.9980, -1.9980]],
 
         [[ 1.5532,  1.5532,  1.5532,  ..., -1.3179, -1.2304, -0.9503],
          [ 1.5532,  1.5532,  1.5532,  ..., -1.3529, -1.2654, -1.0553],
          [ 1.5707,  1.5707,  1.5707,  ..., -1.3704, -1.3354, -1.2304],
          ...,
          [ 2.1134,  2.1134,  2.1134,  ..., -1.9482, -1.9482, -1.9482],
          [ 2.0959,  2.0959,  2.1134,  ..., -1.9482, -1.9482, -1.9482],
          [ 2.0959,  2.0959,  2.1134,  ..., -1.9482, -1.9482, -1.9482]],
 
         [[ 1.9428,  1.9428,  1.9428,  ..., -1.0898, -1.0027, -0.7064],
          [ 1.9428,  

In [5]:
model = DepthEstimator(vkitti_dir=VKITTI_DIR, batch_size=2)
trainer = pl.Trainer(progress_bar_refresh_rate=1, gpus=[0])
trainer.fit(model)

Subset 'rgb' contains 42520 files
Subset 'depth' contains 42520 files
Found a total of 42520 valid files.


HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…



HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …

Subset 'rgb' contains 42520 files
Subset 'depth' contains 42520 files
Found a total of 42520 valid files.
Subset 'rgb' contains 42520 files
Subset 'depth' contains 42520 files
Found a total of 42520 valid files.



1