In [None]:
# Import libraries
import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from IPython.display import clear_output
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [None]:
# Set and freeze seed
def set_seed(seed):
    """Seed every random number generator this notebook has in scope.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU plus all CUDA devices), then puts cuDNN in deterministic mode with
    auto-tuning disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    # `random` is imported by this notebook; leaving it unseeded would make
    # any use of it differ from run to run.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN speed for repeatable convolution results.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed = 105
set_seed(seed)

In [None]:
# Our Model's Architecture
class DeepMetricEye(nn.Module):
    """Convolutional encoder-decoder that maps a 1-channel image to a 1-channel map.

    Structure (as built in ``__init__`` and wired in ``forward``):

    * Encoder: five stages, each two 3x3 conv + BatchNorm + ReLU layers
      followed by 2x2 max pooling, with 64, 128, 256, 512 and 1024 channels.
      Encoder 5 ends with ``Dropout2d(p=0.5)``.
    * Bottleneck: two 3x3 conv + BatchNorm + ReLU layers with 2048 channels,
      followed by ``Dropout2d(p=0.5)``.
    * Decoder 5: fuses four branches brought to the same resolution --
      encoder 3 (max-pooled x4), encoder 4 (max-pooled x2), encoder 5 and the
      bottleneck (upsampled x2). Each branch is reduced to 256 channels, the
      four are concatenated (1024 channels) and passed through one conv block.
    * Decoder 4: same idea one level up -- encoder 3 (max-pooled x2),
      encoder 4, the Decoder 5 output (upsampled x2) and the bottleneck
      (upsampled x4), each reduced to 128 channels and concatenated (512).
    * Decoders 3, 2 and 1: bilinear x2 upsampling followed by two conv blocks
      each; they take no skip connections.
    * Head: 1x1 convolution to one channel followed by a sigmoid.

    The ``C*H*W`` size comments throughout assume a 1*256*256 input. The
    concatenations need all branches to have the same spatial size, which holds
    when the input height and width are multiples of 32 (five 2x poolings).

    Attribute names are the keys of ``state_dict()``; renaming a layer makes
    previously saved checkpoints unloadable.
    """
    def __init__(self):
        """Create all layers. Takes no arguments; every size is fixed here."""
        super(DeepMetricEye, self).__init__()
        # Downsample
        # input 1*256*256
        # Encoder 1
        self.conv1 = nn.Conv2d(1, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)
        # 64*256*256

        # ----------------- #
        self.maxpool1 = nn.MaxPool2d(2, stride=2)
        # ----------------- #
        # 64*128*128

        # Encoder 2
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=True)
        self.conv4 = nn.Conv2d(128, 128, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.relu4 = nn.ReLU(inplace=True)
        # 128*128*128

        # ----------------- #
        self.maxpool2 = nn.MaxPool2d(2, stride=2)
        # ----------------- #
        # 128*64*64

        # Encoder 3
        self.conv5 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.relu5 = nn.ReLU(inplace=True)
        self.conv6 = nn.Conv2d(256, 256, 3, padding=1)
        self.bn6 = nn.BatchNorm2d(256)
        self.relu6 = nn.ReLU(inplace=True)
        # 256*64*64

        # ----------------- #
        self.maxpool3 = nn.MaxPool2d(2, stride=2)
        # ----------------- #
        # 256*32*32

        # Encoder 4
        self.conv7 = nn.Conv2d(256, 512, 3, padding=1)
        self.bn7 = nn.BatchNorm2d(512)
        self.relu7 = nn.ReLU(inplace=True)
        self.conv8 = nn.Conv2d(512, 512, 3, padding=1)
        self.bn8 = nn.BatchNorm2d(512)
        self.relu8 = nn.ReLU(inplace=True)
        # 512*32*32

        # ----------------- #
        self.maxpool4 = nn.MaxPool2d(2, stride=2)
        # ----------------- #
        # 512*16*16

        # Encoder 5
        self.conv9 = nn.Conv2d(512, 1024, 3, padding=1)
        self.bn9 = nn.BatchNorm2d(1024)
        self.relu9 = nn.ReLU(inplace=True)
        self.conv10 = nn.Conv2d(1024, 1024, 3, padding=1)
        self.bn10 = nn.BatchNorm2d(1024)
        self.relu10 = nn.ReLU(inplace=True)
        # Channel-wise dropout; only active in train() mode.
        self.dropout0 = nn.Dropout2d(p=0.5)
        # 1024*16*16

        # ----------------- #
        self.maxpool5 = nn.MaxPool2d(2, stride=2)
        # ----------------- #
        # 1024*8*8

        # Bottleneck
        self.conv20 = nn.Conv2d(1024, 2048, 3, padding=1)
        self.bn20 = nn.BatchNorm2d(2048)
        self.relu20 = nn.ReLU(inplace=True)
        self.conv21 = nn.Conv2d(2048, 2048, 3, padding=1)
        self.bn21 = nn.BatchNorm2d(2048)
        self.relu21 = nn.ReLU(inplace=True)
        self.dropout1 = nn.Dropout2d(p=0.5)
        # 2048*8*8

        # Upsample
        # Decoder 5: four branches, each brought to 256 channels at 16x16,
        # then concatenated and fused.
        # E3 256x64x64
        self.D5_pool1 = nn.MaxPool2d(4, stride=4) # 256x16x16
        self.D5_conv1 = nn.Conv2d(256, 256, 3, padding=1)
        self.D5_bn1 = nn.BatchNorm2d(256)
        self.D5_relu1 = nn.ReLU(inplace=True) # 256x16x16

        # E4 512x32x32
        self.D5_pool2 = nn.MaxPool2d(2, stride=2) # 512x16x16
        self.D5_conv2 = nn.Conv2d(512, 256, 3, padding=1)
        self.D5_bn2 = nn.BatchNorm2d(256)
        self.D5_relu2 = nn.ReLU(inplace=True)
        self.D5_conv3 = nn.Conv2d(256, 256, 3, padding=1)
        self.D5_bn3 = nn.BatchNorm2d(256)
        self.D5_relu3 = nn.ReLU(inplace=True) # 256x16x16

        # E5 1024x16x16
        self.D5_conv4 = nn.Conv2d(1024, 256, 3, padding=1) # 256x16x16
        self.D5_bn4 = nn.BatchNorm2d(256)
        self.D5_relu4 = nn.ReLU(inplace=True)
        self.D5_conv5 = nn.Conv2d(256, 256, 3, padding=1)
        self.D5_bn5 = nn.BatchNorm2d(256)
        self.D5_relu5 = nn.ReLU(inplace=True) # 256x16x16

        # Bottleneck 2048x8x8
        self.D5_upsample1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # 2048*16*16
        self.D5_conv6 = nn.Conv2d(2048, 256, 3, padding=1)
        self.D5_bn6 = nn.BatchNorm2d(256)
        self.D5_relu6 = nn.ReLU(inplace=True)
        self.D5_conv7 = nn.Conv2d(256, 256, 3, padding=1)
        self.D5_bn7 = nn.BatchNorm2d(256)
        self.D5_relu7 = nn.ReLU(inplace=True) # 256x16x16

        # Concat: 4 branches x 256 channels = 1024 channels in and out
        self.D5_conv8 = nn.Conv2d(256*4, 256*4, 3, padding=1)
        self.D5_bn8 = nn.BatchNorm2d(256*4)
        self.D5_relu8 = nn.ReLU(inplace=True)

        # Decoder 4: four branches, each brought to 128 channels at 32x32,
        # then concatenated and fused.
        # E3 256x64x64
        self.D4_pool1 = nn.MaxPool2d(2, stride=2) # 256x32x32
        self.D4_conv1 = nn.Conv2d(256, 128, 3, padding=1)
        self.D4_bn1 = nn.BatchNorm2d(128)
        self.D4_relu1 = nn.ReLU(inplace=True)
        self.D4_conv2 = nn.Conv2d(128, 128, 3, padding=1)
        self.D4_bn2 = nn.BatchNorm2d(128)
        self.D4_relu2 = nn.ReLU(inplace=True) # 128x32x32

        # E4 512x32x32
        self.D4_conv3 = nn.Conv2d(512, 128, 3, padding=1) # 128x32x32
        self.D4_bn3 = nn.BatchNorm2d(128)
        self.D4_relu3 = nn.ReLU(inplace=True)
        self.D4_conv4 = nn.Conv2d(128, 128, 3, padding=1)
        self.D4_bn4 = nn.BatchNorm2d(128)
        self.D4_relu4 = nn.ReLU(inplace=True) # 128x32x32

        # D5 1024x16x16
        self.D4_upsample1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # 1024*32*32
        self.D4_conv5 = nn.Conv2d(1024, 128, 3, padding=1)
        self.D4_bn5 = nn.BatchNorm2d(128)
        self.D4_relu5 = nn.ReLU(inplace=True)
        self.D4_conv6 = nn.Conv2d(128, 128, 3, padding=1)
        self.D4_bn6 = nn.BatchNorm2d(128)
        self.D4_relu6 = nn.ReLU(inplace=True) # 128x32x32

        # Bottleneck 2048x8x8
        self.D4_upsample2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
        # 2048*32*32
        self.D4_conv7 = nn.Conv2d(2048, 128, 3, padding=1)
        self.D4_bn7 = nn.BatchNorm2d(128)
        self.D4_relu7 = nn.ReLU(inplace=True)
        self.D4_conv8 = nn.Conv2d(128, 128, 3, padding=1)
        self.D4_bn8 = nn.BatchNorm2d(128)
        self.D4_relu8 = nn.ReLU(inplace=True) # 128x32x32

        # Concat: 4 branches x 128 channels = 512 channels in and out
        self.D4_conv9 = nn.Conv2d(128*4, 128*4, 3, padding=1)
        self.D4_bn9 = nn.BatchNorm2d(128*4)
        self.D4_relu9 = nn.ReLU(inplace=True)

        # Decoder 3 (no skip connection; consumes the Decoder 4 output only)
        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # 512*64*64
        self.conv13 = nn.Conv2d(512, 256, 3, padding=1)
        self.bn13 = nn.BatchNorm2d(256)
        self.relu13 = nn.ReLU(inplace=True)
        self.conv14 = nn.Conv2d(256, 256, 3, padding=1)
        self.bn14 = nn.BatchNorm2d(256)
        self.relu14 = nn.ReLU(inplace=True)
        # 256*64*64

        # Decoder 2
        self.upsample3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # 256*128*128
        self.conv15 = nn.Conv2d(256, 128, 3, padding=1)
        self.bn15 = nn.BatchNorm2d(128)
        self.relu15 = nn.ReLU(inplace=True)
        self.conv16 = nn.Conv2d(128, 128, 3, padding=1)
        self.bn16 = nn.BatchNorm2d(128)
        self.relu16 = nn.ReLU(inplace=True)
        # 128*128*128

        # Decoder 1
        self.upsample4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # 128*256*256
        self.conv17 = nn.Conv2d(128, 64, 3, padding=1)
        self.bn17 = nn.BatchNorm2d(64)
        self.relu17 = nn.ReLU(inplace=True)
        self.conv18 = nn.Conv2d(64, 64, 3, padding=1)
        self.bn18 = nn.BatchNorm2d(64)
        self.relu18 = nn.ReLU(inplace=True)
        # 64*256*256

        # Output head: 1x1 conv to a single channel, squashed by a sigmoid
        self.conv19 = nn.Conv2d(64, 1, 1, padding=0)
        # 1*256*256
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Run the network on a batch of single-channel images.

        Args:
            x: Tensor with one channel on dimension 1, i.e. (N, 1, H, W);
                the first convolution takes exactly one input channel and the
                branch concatenations are done along dimension 1.

        Returns:
            Tensor with one channel holding sigmoid outputs. With H and W
            multiples of 32 it has the same spatial size as ``x``.
        """
        # Encoder 1
        x1 = self.relu1(self.bn1(self.conv1(x)))
        x2 = self.relu2(self.bn2(self.conv2(x1)))

        x3 = self.maxpool1(x2)

        # Encoder 2
        x3 = self.relu3(self.bn3(self.conv3(x3)))
        x4 = self.relu4(self.bn4(self.conv4(x3)))

        x5 = self.maxpool2(x4)

        # Encoder 3 (x6 is reused below by Decoder 5 and Decoder 4)
        x5 = self.relu5(self.bn5(self.conv5(x5)))
        x6 = self.relu6(self.bn6(self.conv6(x5)))

        x7 = self.maxpool3(x6)

        # Encoder 4 (x8 is reused below by Decoder 5 and Decoder 4)
        x7 = self.relu7(self.bn7(self.conv7(x7)))
        x8 = self.relu8(self.bn8(self.conv8(x7)))

        x9 = self.maxpool4(x8)

        # Encoder 5 (x10, after dropout, is reused by Decoder 5)
        x9 = self.relu9(self.bn9(self.conv9(x9)))
        x10 = self.relu10(self.bn10(self.conv10(x9)))
        x10 = self.dropout0(x10)

        x11 = self.maxpool5(x10)

        # Bottleneck (x12, after dropout, feeds Decoder 5 and Decoder 4)
        x11 = self.relu20(self.bn20(self.conv20(x11)))
        x12 = self.relu21(self.bn21(self.conv21(x11)))
        x12 = self.dropout1(x12)

        # Decoder 5: bring E3, E4, E5 and the bottleneck to one resolution
        d5_e3 = self.D5_pool1(x6)
        d5_e3 = self.D5_relu1(self.D5_bn1(self.D5_conv1(d5_e3)))

        d5_e4 = self.D5_pool2(x8)
        d5_e4 = self.D5_relu2(self.D5_bn2(self.D5_conv2(d5_e4)))
        d5_e4 = self.D5_relu3(self.D5_bn3(self.D5_conv3(d5_e4)))

        d5_e5 = self.D5_relu4(self.D5_bn4(self.D5_conv4(x10)))
        d5_e5 = self.D5_relu5(self.D5_bn5(self.D5_conv5(d5_e5)))

        d5_BN = self.D5_upsample1(x12)
        d5_BN = self.D5_relu6(self.D5_bn6(self.D5_conv6(d5_BN)))
        d5_BN = self.D5_relu7(self.D5_bn7(self.D5_conv7(d5_BN)))

        # Concatenate along the channel dimension, then fuse
        d5 = torch.cat((d5_e3, d5_e4, d5_e5, d5_BN), 1)
        d5 = self.D5_relu8(self.D5_bn8(self.D5_conv8(d5)))

        # Decoder 4: E3, E4, the Decoder 5 output and the bottleneck
        d4_e3 = self.D4_pool1(x6)
        d4_e3 = self.D4_relu1(self.D4_bn1(self.D4_conv1(d4_e3)))
        d4_e3 = self.D4_relu2(self.D4_bn2(self.D4_conv2(d4_e3)))

        d4_e4 = self.D4_relu3(self.D4_bn3(self.D4_conv3(x8)))
        d4_e4 = self.D4_relu4(self.D4_bn4(self.D4_conv4(d4_e4)))

        # Named "e5" but fed by the Decoder 5 output d5, not by encoder 5
        d4_e5 = self.D4_upsample1(d5)
        d4_e5 = self.D4_relu5(self.D4_bn5(self.D4_conv5(d4_e5)))
        d4_e5 = self.D4_relu6(self.D4_bn6(self.D4_conv6(d4_e5)))

        d4_BN = self.D4_upsample2(x12)
        d4_BN = self.D4_relu7(self.D4_bn7(self.D4_conv7(d4_BN)))
        d4_BN = self.D4_relu8(self.D4_bn8(self.D4_conv8(d4_BN)))

        # Concatenate along the channel dimension, then fuse
        d4 = torch.cat((d4_e3, d4_e4, d4_e5, d4_BN), 1)
        d4 = self.D4_relu9(self.D4_bn9(self.D4_conv9(d4)))

        # Decoder 3
        x = self.upsample2(d4)
        x = self.relu13(self.bn13(self.conv13(x)))
        x = self.relu14(self.bn14(self.conv14(x)))

        # Decoder 2
        x = self.upsample3(x)
        x = self.relu15(self.bn15(self.conv15(x)))
        x = self.relu16(self.bn16(self.conv16(x)))

        # Decoder 1
        x = self.upsample4(x)
        x = self.relu17(self.bn17(self.conv17(x)))
        x = self.relu18(self.bn18(self.conv18(x)))

        # Single-channel prediction squashed by the sigmoid
        x = self.conv19(x)
        x = self.sigmoid(x)
        return x

In [None]:
# Loading Data
def load_data(data_dir):
    """Build the train, validation and test datasets stored under ``data_dir``.

    Every split shares one transform pipeline that only converts images to
    tensors.

    Args:
        data_dir: Root directory passed on to ``DepthDataset``.

    Returns:
        Tuple ``(train_dataset, val_dataset, test_dataset)``.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])
    split_names = ('train', 'val', 'test')
    # Built lazily but unpacked immediately, so the splits are created in order.
    train_dataset, val_dataset, test_dataset = (
        DepthDataset(data_dir, split_name, to_tensor) for split_name in split_names
    )
    return train_dataset, val_dataset, test_dataset

# Creating DataLoader
def create_dataloader(dataset, batch_size, shuffle=True, num_workers=0):
    """Wrap ``dataset`` in a ``DataLoader``.

    Args:
        dataset: Dataset to iterate over.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader shuffles the samples (default ``True``).
        num_workers: Number of loader worker processes (default ``0``).

    Returns:
        The configured ``DataLoader``.
    """
    loader_options = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_workers': num_workers,
    }
    return DataLoader(dataset, **loader_options)

In [None]:
# Depth Dataset
class DepthDataset(Dataset):
    """Paired image / depth-map dataset for one split.

    Sample names are read from ``<data_dir>/<split>.txt`` (one per line).
    For each name the input image is ``<data_dir>/<split>/rgb/<name>.png``
    and the target is ``<data_dir>/<split>/depth/<name>.png``.

    Args:
        data_dir: Root directory of the dataset.
        split: Split name, used for both the list file and the sub-directory.
        transform: Optional callable applied separately to the input image
            and to the depth map. NOTE(review): a transform with random
            behaviour would be sampled independently for the two images.
    """

    def __init__(self, data_dir, split, transform=None):
        self.data_dir = data_dir
        self.split = split
        self.transform = transform
        self.image_paths = []
        self.depth_paths = []
        with open(f'{data_dir}/{split}.txt', 'r') as f:
            for line in f:
                name = line.strip()
                # Skip blank lines (e.g. a trailing newline at end of file);
                # they would otherwise become a bogus ".../.png" sample.
                if not name:
                    continue
                self.image_paths.append(f'{data_dir}/{split}/rgb/{name}.png')
                self.depth_paths.append(f'{data_dir}/{split}/depth/{name}.png')

    def __len__(self):
        """Return the number of samples listed for this split."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            Tuple ``(image, depth)``: the blue channel of the input image and
            the depth map converted to mode ``'L'``, each passed through
            ``transform`` when one was given.
        """
        # Context managers release the file handles as soon as the pixel data
        # has been copied out by convert().
        with Image.open(self.image_paths[index]) as rgb_file:
            # Only the blue channel is used as the network input.
            _, _, image = rgb_file.convert('RGB').split()
        with Image.open(self.depth_paths[index]) as depth_file:
            depth = depth_file.convert('L')
        if self.transform is not None:
            image = self.transform(image)
            depth = self.transform(depth)
        return image, depth


In [None]:
# Early Stopping
class EarlyStopping:
    """Stop training when the validation loss stops improving.

    Call the instance once per epoch with the validation loss and the model.
    Whenever the loss improves, the model's ``state_dict`` is written to
    ``checkpoint.pt``; after ``patience`` consecutive calls without
    improvement ``early_stop`` becomes ``True``.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        verbose: If ``True``, print progress messages.
        delta: Minimum decrease of the loss that counts as an improvement.

    Attributes set here: ``counter`` (non-improving calls in a row),
    ``best_score`` (negated best loss, ``None`` until the first checkpoint),
    ``early_stop`` and ``val_loss_min`` (loss of the last checkpoint).
    """

    def __init__(self, patience=7, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf is available in every NumPy release; the capitalised alias
        # np.Inf was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss and checkpoint on improvement.

        Args:
            val_loss: Validation loss of the current epoch (lower is better).
            model: Model whose ``state_dict`` is saved on improvement.
        """
        score = -val_loss
        # NaN is the only value that differs from itself. A NaN loss must
        # never count as an improvement, otherwise it would overwrite the
        # best checkpoint and poison every later comparison.
        is_nan = val_loss != val_loss
        improved = (not is_nan) and (
            self.best_score is None or score >= self.best_score + self.delta
        )
        if improved:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                clear_output(wait=True)
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Save ``model.state_dict()`` to ``checkpoint.pt`` and record the loss.

        Also stores ``val_loss`` in the module-level global ``loss_save_name``,
        which the notebook reads when naming the exported model file.
        """
        global  loss_save_name
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), 'checkpoint.pt')
        loss_save_name = val_loss
        self.val_loss_min = val_loss

In [None]:
# Loss Function
class ReverseHuberLoss(nn.Module):
    """Reverse Huber (berHu) loss.

    For the absolute error ``e = |pred - target|`` and a threshold ``c``:

    * ``e``                        where ``e <= c`` (L1 for small errors)
    * ``(e**2 + c**2) / (2 * c)``  where ``e > c``  (scaled L2 for large errors)

    The two pieces meet at ``e == c``, so the loss is continuous. ``c`` is
    ``threshold_ratio`` times the largest absolute error in the batch and is
    treated as a constant: no gradient flows through it.

    The earlier implementation had the branches swapped (squared error below
    the threshold, absolute error above), which is not the reverse Huber loss
    and jumps at the threshold.

    Args:
        threshold_ratio: Fraction of the batch's maximum absolute error used
            as the threshold ``c``. Defaults to 0.25.
    """

    def __init__(self, threshold_ratio=0.25):
        super(ReverseHuberLoss, self).__init__()
        self.threshold_ratio = threshold_ratio

    def forward(self, pred, target):
        """Return the mean berHu loss between ``pred`` and ``target``.

        Args:
            pred: Predicted tensor.
            target: Ground-truth tensor, broadcastable against ``pred``.

        Returns:
            Scalar tensor with the mean loss over all elements.
        """
        abs_error = (pred - target).abs()
        # Batch-dependent threshold, detached so it acts as a constant.
        c = (self.threshold_ratio * abs_error.max()).detach()
        # A perfect prediction gives c == 0; clamp so the quadratic branch
        # never divides by zero.
        c = torch.clamp(c, min=torch.finfo(c.dtype).eps)
        quadratic = (abs_error ** 2 + c ** 2) / (2 * c)
        loss = torch.where(abs_error <= c, abs_error, quadratic).mean()
        return loss

In [None]:
# Training Function
def train(model, train_dataloader, val_dataloader, epochs, lr, device, early_stopping_patience):
    """Train ``model`` with Adam and ``ReverseHuberLoss``, with early stopping.

    Each epoch runs one pass over ``train_dataloader`` and one gradient-free
    pass over ``val_dataloader``. ``EarlyStopping`` checkpoints the model to
    ``checkpoint.pt`` and ends training once it reports ``early_stop``. The
    checkpointed weights are loaded back into ``model`` before returning.

    Args:
        model: Network to optimise (modified in place).
        train_dataloader: Loader yielding ``(inputs, targets)`` training batches.
        val_dataloader: Loader yielding ``(inputs, targets)`` validation batches.
        epochs: Maximum number of epochs.
        lr: Learning rate for Adam.
        device: Device the batches are moved to.
        early_stopping_patience: Patience passed to ``EarlyStopping``.

    Returns:
        Tuple ``(train_loss_history, val_loss_history)``: per-epoch mean batch
        losses.
    """
    criterion = ReverseHuberLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    train_loss_history = []
    val_loss_history = []
    early_stopping = EarlyStopping(patience=early_stopping_patience, verbose=True)

    for epoch in range(epochs):
        # Training pass
        model.train()
        train_loss = 0.0
        for inputs, targets in tqdm(train_dataloader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss /= len(train_dataloader)
        train_loss_history.append(train_loss)

        # Validation pass (no gradients, eval-mode dropout / batch norm)
        val_loss = 0.0
        model.eval()
        with torch.no_grad():
            for inputs, targets in tqdm(val_dataloader):
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()
        val_loss /= len(val_dataloader)
        val_loss_history.append(val_loss)

        early_stopping(val_loss, model)

        # Report the epoch before a possible break so the metrics of the
        # final epoch are not lost when early stopping fires.
        print(f'Epoch {epoch+1}/{epochs}: train_loss={train_loss:.4f}, val_loss={val_loss:.4f}')

        if early_stopping.early_stop:
            print("Early stopping")
            break

    # Restore the checkpoint only if this run wrote one; otherwise a stale
    # checkpoint.pt from an earlier run could be loaded (e.g. epochs == 0).
    if early_stopping.best_score is not None:
        # map_location keeps the restored tensors on the training device.
        model.load_state_dict(torch.load('checkpoint.pt', map_location=device))
    return train_loss_history, val_loss_history

# Test Function
def test(model, test_dataloader, device):
    """Evaluate ``model`` on ``test_dataloader`` with ``ReverseHuberLoss``.

    Runs in eval mode without gradients, prints the mean batch loss and
    returns it so callers can log or compare it.

    Args:
        model: Network to evaluate.
        test_dataloader: Loader yielding ``(inputs, targets)`` batches.
        device: Device the batches are moved to.

    Returns:
        Mean of the per-batch losses as a float.
    """
    criterion = ReverseHuberLoss()
    test_loss = 0.0
    model.eval()
    with torch.no_grad():
        for inputs, targets in test_dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            test_loss += criterion(outputs, targets).item()
    test_loss /= len(test_dataloader)
    print(f'Test loss: {test_loss:.4f}')
    return test_loss

In [None]:
# Initialise Model and Start Training
# Run configuration
data_dir = './train_data'
batch_size = 30
epochs = 150
lr = 1e-4
es_patience = 20
# Overwritten by EarlyStopping.save_checkpoint with the best validation loss.
loss_save_name = 0
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

start_time = time.time()

train_dataset, val_dataset, test_dataset = load_data(data_dir)

train_dataloader = create_dataloader(train_dataset, batch_size)
# Evaluation loaders keep a fixed order: ReverseHuberLoss derives its
# threshold from each batch's maximum error, so reshuffling would change the
# validation loss between epochs for identical weights and add noise to the
# early-stopping decision.
val_dataloader = create_dataloader(val_dataset, batch_size, shuffle=False)
test_dataloader = create_dataloader(test_dataset, batch_size, shuffle=False)

model = DeepMetricEye().to(device)

train_loss_history, val_loss_history = train(model, train_dataloader, val_dataloader, epochs, lr, device, es_patience)

test(model, test_dataloader, device)

filename = f"model_L4_test_bs{batch_size}_lr{lr}_epoch{epochs}_ReverseHuber_Loss{loss_save_name:.4f}_seed{seed}.pth"
# Create the output directory up front so a missing folder cannot make the
# final save fail after the whole training run.
os.makedirs("./model", exist_ok=True)
save_path = "./model/" + filename
torch.save(model.state_dict(), save_path)
print("Model saved to %s" % save_path)
end_time = time.time()
print(f'Time cost: {(end_time - start_time) / 60:.2f} minutes')