# Interpretable Depth Estimation Project
By Lien Huong and Evgeniia.

This work is based on the following papers:
1. You, Z., Tsai, Y.-H., Chiu, W.-C., and Li, G. (2021). Towards Interpretable Deep Networks for Monocular Depth Estimation. arXiv.
2. Papa, L., Russo, P., and Amerini, I. (2023). METER: A Mobile Vision Transformer Architecture for Monocular Depth Estimation. IEEE Transactions on Circuits and Systems for Video Technology, 33(10), 5882–5893.

## Imports

In [None]:
import os
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.nn.functional as F

import h5py
from PIL import Image

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import urllib.request
from pathlib import Path

import gc

from einops import rearrange

## Globals

In [None]:
# Input resolution as (unused leading slot, height, width); `mobilevit_xxs`
# reads entries 1 and 2.
RGB_img_res = (None, 192, 256)

# True: train with neuron selectivity; False: train with balanced loss function.
neuron_selective = True

# ---- Training configuration ----
num_epochs = 20

# Neuron-selectivity settings (only relevant when `neuron_selective` is True).
# NOTE(review): `imde_alpha` presumably weights the selectivity term in the
# training loop — confirm there. `cfg_layers` are the regularized layers.
imde_alpha = -0.1
cfg_layers = ["dec_3", "enc_4"]

In [None]:
# Memory/speed-related PyTorch settings (previously set twice; once is enough).
# cuDNN autotuning picks the fastest convolution algorithms; inputs here are
# resized to a fixed resolution, which is the case where this pays off.
torch.backends.cudnn.benchmark = True
# Expandable allocator segments reduce CUDA memory fragmentation. Per the
# PyTorch CUDA memory-management docs this must be set before the caching
# allocator is first used.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

def clear_gpu_memory():
    """Release cached CUDA memory and reset the peak-memory statistics.

    A no-op when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Utils

In [None]:
def get_num_units_bins_and_nums_for_plt(layers, default_num_bins=64):
    """Return per-layer unit counts, depth-bin counts and subplot layouts.

    Args:
        layers: layer names understood by `get_features` (e.g. "enc_4", "dec_3").
        default_num_bins: upper bound on the number of depth bins of a layer;
            a layer with fewer units gets one bin per unit.

    Returns:
        Six parallel lists: number of units, number of depth bins, subplot
        rows, subplot columns, figure width and figure height per layer.
        For every layer ``rows * cols == num_units`` so each unit gets exactly
        one subplot.

    Raises:
        KeyError: if a layer name is not in the table below.
    """
    num_units_layers_dict = {
        "enc_output": 160,
        "enc_0": 16, "enc_1": 16, "enc_2": 24, "enc_3": 24, "enc_4": 24, "enc_5": 48,  # conv1, mv1-4
        "enc_6": 48, "enc_7": 64, "enc_8": 64, "enc_9": 80, "enc_10": 80,  # meter0, mv5, meter1, mv6, meter2
        "dec_0": 64, "dec_1": 32, "dec_2": 16, "dec_3": 8  # conv1, upsampling1, upsampling2, upsampling3
    }
    # (rows, cols, figure height, figure width) keyed by unit count.
    layouts = {
        8: (2, 4, 10, 20),
        16: (4, 4, 15, 20),
        24: (4, 6, 15, 30),
        # dec_1: previously fell through to the 160-unit 16x10 grid, which made
        # the plotting loop index units past the end of the layer.
        32: (4, 8, 15, 40),
        48: (6, 8, 30, 50),
        64: (8, 8, 30, 50),
        80: (8, 10, 30, 60),
        160: (16, 10, 60, 50),
    }

    num_units_list = [num_units_layers_dict[layer] for layer in layers]
    num_bins_list = [min(num_units, default_num_bins) for num_units in num_units_list]

    # for plotting
    plt_row_list, plt_col_list, figsize_w_list, figsize_h_list = [], [], [], []
    for num_units in num_units_list:
        layout = layouts.get(num_units)
        if layout is None:
            # Near-square grid whose cells exactly cover all units (only hit if
            # the unit table above is extended with a new size).
            plt_row = max(d for d in range(1, int(num_units ** 0.5) + 1) if num_units % d == 0)
            plt_col = num_units // plt_row
            layout = (plt_row, plt_col, 5 * plt_row, 5 * plt_col)
        plt_row, plt_col, figsize_h, figsize_w = layout

        plt_row_list.append(plt_row)
        plt_col_list.append(plt_col)
        figsize_h_list.append(figsize_h)
        figsize_w_list.append(figsize_w)

    return num_units_list, num_bins_list, plt_row_list, plt_col_list, figsize_w_list, figsize_h_list


def get_sid_thresholds(K):
    """Return spacing-increasing-discretization (SID) depth thresholds.

    For K >= 2 the list has K + 1 entries: 0, then ``alpha`` (0.75), then
    values growing exponentially in the index, ending at ``beta`` (100).
    Near depths therefore get narrower bins than far ones. `discretize_depths`
    uses bin ``i`` for depths in ``(thresholds[i], thresholds[i + 1]]``.
    """
    alpha, beta = 0.75, 100.
    shift = 1 - alpha
    log_span = np.log(beta + shift)
    interior = [np.exp((log_span * i) / (K - 1)) - shift for i in range(1, K - 1)]
    return [0., alpha, *interior, beta]


def discretize_depths(depths, K=None, thresholds=None):
    """Map continuous depths to integer bin indices (stored as floats).

    Args:
        depths: depth tensor of any shape.
        K: number of bins; used to build SID thresholds via
            `get_sid_thresholds` when ``thresholds`` is not given.
        thresholds: explicit increasing bin edges of length K + 1 (list,
            ndarray or tensor).

    Returns:
        Float tensor shaped like ``depths`` on the same device, containing
        ``-1`` for non-positive (invalid) depths and ``i`` for depths in
        ``(thresholds[i], thresholds[i + 1]]``. Depths above the last edge
        are assigned to the last bin ``K - 1`` (previously they stayed 0 and
        were silently counted as the nearest bin).

    Raises:
        AssertionError: if neither K nor thresholds is given, or they disagree.
    """
    # Explicit None/length checks: `not thresholds` is ambiguous for arrays.
    has_thresholds = thresholds is not None and len(thresholds) > 0
    assert K or has_thresholds, "At least one of them should be given."
    if not has_thresholds:
        thresholds = get_sid_thresholds(K)
    elif not K:
        K = len(thresholds) - 1
    else:
        assert K == len(thresholds) - 1

    discrete_depths = torch.zeros(depths.shape, device=depths.device)
    discrete_depths[depths <= 0] = -1  # invalid depth -> -1
    for i in range(K):
        discrete_depths[(depths > thresholds[i]) & (depths <= thresholds[i + 1])] = i
    # Saturate out-of-range (too far) depths into the farthest bin.
    discrete_depths[depths > thresholds[K]] = K - 1
    return discrete_depths


def count_discrete_depth(discrete_depth, discrete_depth_cnt=None):
    """Accumulate per-bin pixel counts of a batch of discretized depth maps.

    Args:
        discrete_depth: tensor or ndarray of bin labels as produced by
            `discretize_depths`, indexed ``[image, channel, ...]``. Only
            channel 0 is counted, and negative (invalid) labels are ignored.
        discrete_depth_cnt: optional 1-D running count to add to. It is
            updated in place, and grown by concatenation (returning a new
            tensor) when the batch contains a bin beyond its length.

    Returns:
        The updated count tensor (double precision when created here).
    """
    # Accept ndarrays too (the old code called torch.where on them and failed).
    discrete_depth = torch.as_tensor(discrete_depth)
    if discrete_depth_cnt is None:
        discrete_depth_cnt = torch.zeros(1, dtype=torch.double, device=discrete_depth.device)
    if len(discrete_depth) == 0:
        return discrete_depth_cnt

    # One pass over all pixels instead of one torch.where per (image, bin).
    labels = discrete_depth[:, 0].reshape(-1).to(discrete_depth_cnt.device)
    labels = labels[labels >= 0].long()
    if labels.numel() == 0:
        return discrete_depth_cnt

    num_bins = int(labels.max()) + 1
    if num_bins > len(discrete_depth_cnt):  # need to extend discrete_depth_cnt
        padding = torch.zeros(num_bins - len(discrete_depth_cnt),
                              dtype=discrete_depth_cnt.dtype, device=discrete_depth_cnt.device)
        discrete_depth_cnt = torch.cat([discrete_depth_cnt, padding], dim=-1)

    counts = torch.bincount(labels, minlength=num_bins)
    discrete_depth_cnt[:num_bins] += counts.to(discrete_depth_cnt.dtype)
    return discrete_depth_cnt


def add_fmaps_on_discrete_depths(discrete_depths, fmaps, discrete_depth_sum=None):
    """Accumulate, per unit and per depth bin, the sum of feature responses.

    ``fmaps`` is bilinearly resized to the spatial size of ``discrete_depths``.
    Then, for every pixel whose channel-0 label is a valid (non-negative) bin,
    the pixel's unit vector is added to column ``label`` of the running sum.

    Args:
        discrete_depths: (N, C_d, H, W) bin labels from `discretize_depths`;
            only channel 0 is used.
        fmaps: (N, U, h, w) feature maps. Gradients flow through the result,
            which the selectivity regularizer relies on.
        discrete_depth_sum: optional (U, K) running sum, updated in place and
            grown (returning a new tensor) if a higher bin appears.

    Returns:
        The updated (U, K') sum (double precision when created here).
    """
    num_units = fmaps.shape[1]
    if discrete_depth_sum is None:
        discrete_depth_sum = torch.zeros((num_units, 1), dtype=torch.double, device=fmaps.device)
    h, w = discrete_depths.shape[2], discrete_depths.shape[3]
    fmaps = F.interpolate(fmaps, (h, w), mode='bilinear', align_corners=False)

    # Vectorized over the whole batch instead of one masked gather per
    # (image, bin) with a host sync each time.
    labels = discrete_depths[:, 0].reshape(-1).to(fmaps.device)
    valid = labels >= 0
    if not bool(valid.any()):
        return discrete_depth_sum
    labels = labels[valid].long()

    num_bins = int(labels.max()) + 1
    if num_bins > discrete_depth_sum.shape[1]:  # need to extend discrete_depth_sum
        padding = torch.zeros((discrete_depth_sum.shape[0], num_bins - discrete_depth_sum.shape[1]),
                              dtype=discrete_depth_sum.dtype, device=discrete_depth_sum.device)
        discrete_depth_sum = torch.cat([discrete_depth_sum, padding], dim=-1)

    # (U, N*H*W): pixel order matches `labels` (image-major, then row, column).
    responses = fmaps.transpose(0, 1).reshape(num_units, -1)[:, valid]
    discrete_depth_sum.index_add_(1, labels.to(discrete_depth_sum.device),
                                  responses.to(discrete_depth_sum.dtype))
    return discrete_depth_sum


def compute_selectivity(depth_avg_response, unit_max_bin):
    """Depth-selectivity index of every unit.

    For unit ``u`` with preferred bin ``b = unit_max_bin[u]``::

        selectivity = (r_max - r_rest) / (r_max + r_rest + 1e-15)

    where ``r_max`` is the unit's average response in bin ``b`` and ``r_rest``
    is the mean of its responses over all other bins.

    Args:
        depth_avg_response: (units, bins) tensor or ndarray of average responses.
        unit_max_bin: per-unit bin index (tensor or ndarray of ints).

    Returns:
        (units,) torch tensor of selectivity indices.
    """
    responses = depth_avg_response
    if isinstance(responses, np.ndarray):
        responses = torch.from_numpy(responses)
    num_units, num_bins = responses.shape[0], responses.shape[1]
    r_max = responses[np.arange(num_units), unit_max_bin]
    r_rest = (responses.sum(1) - r_max) / (num_bins - 1)
    return (r_max - r_rest) / (r_max + r_rest + 1e-15)


def get_features(model, image, layer_list, also_get_output=True, is_training=False):
    """Run ``model`` once and collect the requested intermediate activations.

    Layer names:
        ``'enc_output'``  final encoder output;
        ``'enc_<i>'``     i-th entry of the encoder's feature list;
        ``'dec_<i>'``     i-th entry of the decoder's feature list.

    Args:
        model: object with ``encoder(image) -> (x, enc_feats)`` and
            ``decoder(x, enc_feats) -> (output, dec_feats)``.
        image: input batch passed to ``model.encoder``.
        layer_list: names of the layers to collect, in order.
        also_get_output: if True, the decoder output is appended last.
        is_training: if False, the forward pass runs under ``torch.no_grad()``.

    Returns:
        List of tensors in ``layer_list`` order (plus the output if requested).

    Raises:
        NotImplementedError: for an unsupported layer name.
    """
    def _select(name, enc_out, enc_feats, dec_feats):
        if name == 'enc_output':
            return enc_out
        if name.startswith('enc_'):  # e.g. 'enc_0' ... 'enc_10'
            return enc_feats[int(name.split('_')[1])]
        if name.startswith('dec_'):  # e.g. 'dec_0' ... 'dec_3'
            return dec_feats[int(name.split('_')[1])]
        raise NotImplementedError(f"Layer {name} is not supported.")

    def _run():
        enc_out, enc_feats = model.encoder(image)
        output, dec_feats = model.decoder(enc_out, enc_feats)
        collected = [_select(name, enc_out, enc_feats, dec_feats) for name in layer_list]
        if also_get_output:
            collected.append(output)
        return collected

    if is_training:
        return _run()
    with torch.no_grad():
        return _run()

### Dissection

In [None]:
def compute_and_plot_selectivity(model, layers, data_loader, viz_plot=True):
    """Dissect ``layers`` of ``model`` for depth selectivity over ``data_loader``.

    Processing steps:
    1. For each layer, accumulate every unit's summed response per SID depth bin
       and the pixel count per bin over the whole loader.
    2. Average the sums by the counts.
    3. Take the absolute value of the averages. The bin with the largest value
       is the unit's preferred bin.
    4. Compute the unit's selectivity with `compute_selectivity`.

    The mean selectivity of each layer is printed. If ``viz_plot`` is True, a
    per-unit bar chart is saved as ``'<layer>-<mean selectivity>.png'``, shown,
    and then closed.

    NOTE: the model is used in whatever train/eval mode it is in; call
    ``model.eval()`` first if BatchNorm/Dropout should run in inference mode.

    Returns:
        ``(unit_max_bin_list, selectivity_index_list)``: per layer, the
        preferred bin of each unit and its selectivity index (numpy arrays).
    """
    num_units_list, num_bins_list, plt_row_list, plt_col_list, figsize_w_list, figsize_h_list = get_num_units_bins_and_nums_for_plt(
        layers=layers
    )

    depth_cnt_list = [torch.zeros(num_bins, dtype=torch.double).to(device) for num_bins in num_bins_list]
    depth_sum_list = [torch.zeros((num_units, num_bins), dtype=torch.double).to(device)
                      for num_units, num_bins in zip(num_units_list, num_bins_list)]
    sid_thresholds_list = [get_sid_thresholds(num_bins) for num_bins in num_bins_list]

    for images, depths in tqdm(data_loader, desc='Dissection'):
        imgs = images.to(device)
        # Was a no-op `depths = depths`: keep labels on the features' device.
        depths = depths.to(device)
        features_list = get_features(model, imgs, layers, also_get_output=False)
        for layer_i, features in enumerate(features_list):
            discrete_concepts = discretize_depths(depths, thresholds=sid_thresholds_list[layer_i])
            depth_cnt_list[layer_i] = count_discrete_depth(discrete_concepts, depth_cnt_list[layer_i])
            depth_sum_list[layer_i] = add_fmaps_on_discrete_depths(discrete_concepts, features,
                                                                   depth_sum_list[layer_i])

    depth_avg_list = []
    for layer_i in range(len(layers)):
        depth_cnt_list[layer_i] = depth_cnt_list[layer_i].cpu().numpy()
        depth_sum_list[layer_i] = depth_sum_list[layer_i].cpu().numpy()
        depth_avg_list.append(depth_sum_list[layer_i] / (depth_cnt_list[layer_i] + 1e-15))

    unit_max_bin_list, selectivity_index_list = [], []
    for layer_i, layer in enumerate(layers):
        # ---- compute selectivity ----
        depth_avg_abs = np.abs(depth_avg_list[layer_i])
        unit_max_bin = depth_avg_abs.argmax(-1)
        selectivity_index = compute_selectivity(depth_avg_abs, unit_max_bin).numpy()
        unit_max_bin_list.append(unit_max_bin)
        selectivity_index_list.append(selectivity_index)
        selec_mean = selectivity_index.mean()
        print(f'selec_mean_{layer}', selec_mean)

        # ---- plot ----
        if not viz_plot:
            continue
        sns.set_theme(style="darkgrid")
        n_rows, n_cols = plt_row_list[layer_i], plt_col_list[layer_i]
        # squeeze=False keeps `ax` 2-D even for a single row/column.
        fig, ax = plt.subplots(n_rows, n_cols,
                               figsize=(figsize_w_list[layer_i], figsize_h_list[layer_i]),
                               squeeze=False)
        bins = range(len(depth_cnt_list[layer_i]))
        for i in range(n_rows):
            for j in range(n_cols):
                unit_i = i * n_cols + j
                if unit_i >= len(selectivity_index):  # grid larger than the layer
                    ax[i, j].axis('off')
                    continue
                ax[i, j].set_title(
                    f'Unit {unit_i}, Layer {layer}, Selectivity: {selectivity_index[unit_i]:.04f}',
                    fontsize=12)
                # using abs value here (originally depth_avg)
                ax[i, j].bar(bins, depth_avg_abs[unit_i], color='red', width=0.85)
                # plt.xticks/yticks only touched the current (last) axes.
                ax[i, j].tick_params(labelsize=12)
        fig.subplots_adjust(wspace=0.15)
        save_fig_name = f'{layer}-{selec_mean:.4f}.png'
        fig.savefig(save_fig_name, bbox_inches='tight')
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across layers

    return unit_max_bin_list, selectivity_index_list

### Metric

In [None]:
def compute_rmse(outputs, depths):
    """Root-mean-square error over all elements, returned as a Python float."""
    squared_error = (outputs - depths) ** 2
    return torch.sqrt(squared_error.mean()).item()

### Visualization

In [None]:
def visualize_predictions(model, train_loader, val_loader, device, epoch, num_samples=2, save=False):
    """Visualize predictions on training and validation data.

    Uses the first batch of each loader. Each sample gets one row with three
    panels: the de-normalized input image, the ground-truth depth and the
    predicted depth. Training and validation rows alternate. The figure is
    shown and, when ``save`` is True, also written to
    ``training_visualizations/epoch_<epoch + 1>.png``.

    The model's previous train/eval mode is restored afterwards, so this can
    be called in the middle of training.

    Args:
        model: depth network called as ``model(images)``.
        train_loader, val_loader: iterables yielding ``(images, depths)``.
        device: device the inputs are moved to before prediction.
        epoch: zero-based epoch index (displayed as ``epoch + 1``).
        num_samples: samples per split; capped at the batch sizes.
        save: whether to also save the figure.
    """
    # Previously the model was left in eval mode, silently freezing BatchNorm
    # statistics and disabling dropout for the rest of training.
    was_training = model.training
    model.eval()
    try:
        # Get batches of training and validation data
        train_images, train_depths = next(iter(train_loader))
        val_images, val_depths = next(iter(val_loader))

        # Generate predictions
        with torch.no_grad():
            train_predictions = model(train_images.to(device))
            val_predictions = model(val_images.to(device))
    finally:
        model.train(was_training)

    # Convert tensors to numpy arrays
    train_images = train_images.cpu().numpy()
    train_depths = train_depths.cpu().numpy()
    train_predictions = train_predictions.cpu().numpy()

    val_images = val_images.cpu().numpy()
    val_depths = val_depths.cpu().numpy()
    val_predictions = val_predictions.cpu().numpy()

    num_samples = min(num_samples, len(train_images), len(val_images))

    # Same statistics as the transforms.Normalize used by the data loaders.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    fig, axes = plt.subplots(2 * num_samples, 3, figsize=(15, 5 * 2 * num_samples), squeeze=False)
    fig.suptitle(f'Training and Validation Results - Epoch {epoch+1}', fontsize=16)

    def plot_sample(img, depth_gt, depth_pred, row_idx, title_prefix):
        # Original image: undo normalization, then clip so float round-off
        # cannot wrap around when casting to uint8.
        img_display = (np.transpose(img, (1, 2, 0)) * std + mean) * 255
        img_display = np.clip(img_display, 0, 255).astype(np.uint8)
        axes[row_idx, 0].imshow(img_display)
        axes[row_idx, 0].set_title(f'{title_prefix} Input Image')
        axes[row_idx, 0].axis('off')

        # Ground truth depth
        axes[row_idx, 1].imshow(depth_gt[0], cmap='plasma')
        axes[row_idx, 1].set_title(f'{title_prefix} Ground Truth Depth')
        axes[row_idx, 1].axis('off')

        # Predicted depth
        axes[row_idx, 2].imshow(depth_pred[0], cmap='plasma')
        axes[row_idx, 2].set_title(f'{title_prefix} Predicted Depth')
        axes[row_idx, 2].axis('off')

    # Even rows: training samples; odd rows: validation samples.
    for i in range(num_samples):
        plot_sample(train_images[i], train_depths[i], train_predictions[i], 2 * i, 'Train:')
        plot_sample(val_images[i], val_depths[i], val_predictions[i], 2 * i + 1, 'Val:')

    fig.tight_layout()
    if save:
        os.makedirs('training_visualizations', exist_ok=True)
        fig.savefig(f'training_visualizations/epoch_{epoch+1}.png')
    plt.show()
    plt.close(fig)  # one figure per call would otherwise accumulate

## Data: NYUDepthV2
https://www.kaggle.com/datasets/soumikrakshit/nyu-depth-v2

In [None]:
# DataLoader settings.
NUM_WORKERS = 2
BATCH_SIZE = 64

# Kaggle "nyu-depth-v2" layout: the CSV lists image/depth paths that are
# resolved relative to `base_data_path`.
base_data_path = '/kaggle/input/nyu-depth-v2/nyu_data'
train_csv_path = '/kaggle/input/nyu-depth-v2/nyu_data/data/nyu2_train.csv'

In [None]:
class NYUDepthV2Dataset(Dataset):
    """NYU Depth V2 samples listed in a DataFrame of (image path, depth path).

    Each item is resized to fixed resolutions: the RGB image to 192x256 with
    bilinear filtering, and the depth map to 48x64 with nearest-neighbour so
    depth values are never blended.
    """

    def __init__(self, filenames_df, transform=None):
        """
        Args:
            filenames_df: DataFrame whose column 0 holds image paths and
                column 1 depth-map paths.
            transform: optional callable applied to the resized PIL RGB image
                (e.g. ToTensor + Normalize).
        """
        self.filenames_df = filenames_df
        self.transform = transform
        # Resolutions as constants from the paper, in (H, W) order.
        self.RGB_SIZE = (192, 256)  # input image
        self.DEPTH_SIZE = (48, 64)  # depth map

    def __len__(self):
        return len(self.filenames_df)

    def __getitem__(self, idx):
        """Return ``(image, depth)``.

        ``depth`` is a float tensor with a leading channel axis; for a
        single-channel depth PNG (assumed — confirm) that is (1, 48, 64).
        """
        img_path = self.filenames_df.iloc[idx, 0]
        depth_path = self.filenames_df.iloc[idx, 1]
        rgb_h, rgb_w = self.RGB_SIZE
        depth_h, depth_w = self.DEPTH_SIZE

        # PIL's resize takes (width, height). The context managers close the
        # source files once the resized copies exist.
        with Image.open(img_path) as img_file:
            image = img_file.convert('RGB').resize((rgb_w, rgb_h), Image.BILINEAR)
        with Image.open(depth_path) as depth_file:
            depth = depth_file.resize((depth_w, depth_h), Image.NEAREST)

        if self.transform:
            image = self.transform(image)
        depth_array = np.array(depth)
        depth = torch.from_numpy(depth_array).float().unsqueeze(0)  # add channel dimension
        return image, depth

In [None]:
def get_dataset_loaders(train_csv_path, base_data_path, batch_size, num_workers):
    """Build train/val/test DataLoaders from the NYU Depth V2 training CSV.

    The header-less CSV lists image and depth paths, which are resolved
    against ``base_data_path``. Rows are split in file order (no shuffling)
    into 80% train, 10% val and the remainder as test. Only the train loader
    shuffles; all loaders use pinned memory and the same
    ToTensor + ImageNet-Normalize transform.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    base_path = Path(base_data_path)
    filenames_df = pd.read_csv(Path(train_csv_path), header=None)
    for column in (0, 1):
        filenames_df[column] = filenames_df[column].map(lambda rel: base_path / rel)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    n_total = len(filenames_df)
    n_train = int(0.8 * n_total)
    n_val = int(0.1 * n_total)
    train_df = filenames_df[:n_train]
    val_df = filenames_df[n_train:n_train + n_val]
    test_df = filenames_df[n_train + n_val:]

    print(f"Dataset splits: Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")

    def make_loader(split_df, shuffle):
        return DataLoader(
            NYUDepthV2Dataset(split_df, transform),
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=True
        )

    train_loader = make_loader(train_df, shuffle=True)
    val_loader = make_loader(val_df, shuffle=False)
    test_loader = make_loader(test_df, shuffle=False)
    return train_loader, val_loader, test_loader

In [None]:
# Build the train/val/test loaders from the Kaggle CSV.
train_loader, val_loader, test_loader = get_dataset_loaders(
    train_csv_path=train_csv_path,
    base_data_path=base_data_path,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

## Network: METER model (XXS)

In [None]:
class SeparableConv2d(nn.Module):
    """Two-stage convolution: a ``kernel_size`` conv followed by a 1x1 conv.

    NOTE: with the default ``depth=1`` the first conv uses ``groups=1``, so it
    is a full (not depthwise) convolution despite its attribute name. It is
    kept as-is so parameter shapes and state_dict keys stay unchanged.
    Padding is fixed to 1, which preserves spatial size (at stride 1) only for
    3x3 kernels, the only size used in this file.

    Args:
        device: optional device to place the layers on eagerly. ``None``
            (now allowed) leaves them where they were created, e.g. so that
            the whole model can be moved later with ``.to(device)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, device=None, stride=1, depth=1, bias=False):
        super(SeparableConv2d, self).__init__()
        self.depthwise = nn.Conv2d(in_channels, out_channels * depth,
                                   kernel_size=kernel_size,
                                   groups=depth,
                                   padding=1,
                                   stride=stride,
                                   bias=bias)
        self.pointwise = nn.Conv2d(out_channels * depth, out_channels, kernel_size=(1, 1), bias=bias)
        if device is not None:
            self.to(device)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out

def conv_1x1_bn(inp, oup):
    """Pointwise (1x1, bias-free) convolution followed by BatchNorm and ReLU."""
    pointwise = nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False)
    return nn.Sequential(pointwise, nn.BatchNorm2d(oup), nn.ReLU())

def conv_nxn_bn(inp, oup, kernal_size=3, stride=1):
    """SeparableConv2d (n x n, bias-free) followed by BatchNorm and ReLU.

    ``kernal_size`` (sic) keeps its original spelling for callers.
    """
    # Previously hard-coded to 'cuda:0', which crashes on CPU-only machines.
    conv_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    return nn.Sequential(
        SeparableConv2d(in_channels=inp, out_channels=oup, kernel_size=kernal_size, stride=stride,
                        bias=False, device=conv_device),
        nn.BatchNorm2d(oup),
        nn.ReLU()
    )

class PreNorm(nn.Module):
    """Apply LayerNorm over the last dimension, then the wrapped module ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)

class FeedForward(nn.Module):
    """Token-wise MLP: Linear(dim->hidden) -> ReLU -> Dropout -> Linear(hidden->dim) -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    """Multi-head self-attention over the token axis.

    Input and output layout is ``(b, p, n, dim)``, as fixed by the rearrange
    patterns below. The ``p`` axis is carried through untouched, so attention
    mixes only the ``n`` tokens within each ``p`` slice.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        # A single head as wide as `dim` needs no output projection.
        needs_projection = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        if needs_projection:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        def split_heads(t):
            return rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads)

        q, k, v = (split_heads(t) for t in self.to_qkv(x).chunk(3, dim=-1))
        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.attend(scores)
        context = torch.matmul(weights, v)
        merged = rearrange(context, 'b p h n d -> b p n (h d)')
        return self.to_out(merged)

class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm blocks: residual attention, then residual MLP."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            attention = PreNorm(dim, Attention(dim, heads, dim_head, dropout))
            mlp = PreNorm(dim, FeedForward(dim, mlp_dim, dropout))
            self.layers.append(nn.ModuleList([attention, mlp]))

    def forward(self, x):
        for attention, mlp in self.layers:
            x = x + attention(x)
            x = x + mlp(x)
        return x

class MV2Block(nn.Module):
    """MobileNetV2 inverted residual: 1x1 expand -> 3x3 depthwise -> 1x1 project.

    The input is added back (residual) when the block keeps both the spatial
    size (stride 1) and the channel count.
    """

    def __init__(self, inp, oup, stride=1, expansion=4):
        super().__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(inp * expansion)
        self.use_res_connect = stride == 1 and inp == oup

        self.conv = nn.Sequential(
            # expand
            nn.Conv2d(inp, hidden_dim, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
            # depthwise
            nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=stride, padding=1,
                      groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU(),
            # linear projection (no activation)
            nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(oup),
        )

    def forward(self, x):
        out = self.conv(x)
        return x + out if self.use_res_connect else out

class MobileViTBlock(nn.Module):
    """MobileViT block: local conv features, patch-wise transformer, fusion.

    The output has ``channel`` channels and the same spatial size as the
    input. The input height/width must be divisible by ``patch_size``.
    """

    def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
        super().__init__()
        self.ph, self.pw = patch_size

        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        self.conv2 = conv_1x1_bn(channel, dim)

        # 4 heads of width 8.
        self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)

        self.conv3 = conv_1x1_bn(dim, channel)
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)

    def forward(self, x):
        shortcut = x.clone()

        # Local representation: n x n conv, then 1x1 projection to `dim`.
        local = self.conv2(self.conv1(x))

        # Global representation: group pixels by their position inside each
        # ph x pw patch, attend across patches, then fold back.
        _, _, height, width = local.shape
        tokens = rearrange(local, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
        tokens = self.transformer(tokens)
        global_feats = rearrange(tokens, 'b (ph pw) (h w) d -> b d (h ph) (w pw)',
                                 h=height // self.ph, w=width // self.pw, ph=self.ph, pw=self.pw)

        # Fusion: back to `channel` channels, concat with the input, n x n conv.
        fused = torch.cat((self.conv3(global_feats), shortcut), 1)
        return self.conv4(fused)


class MobileViT(nn.Module):
    """MobileViT encoder used as the METER backbone.

    Stem conv (stride 2) -> MV2 blocks -> three (stride-2 MV2, MobileViT
    block) stages -> 1x1 conv. ``forward`` returns the final features and a
    list of intermediate activations; entry ``i`` of that list is what
    `get_features` calls ``'enc_i'``, and the final features are ``'enc_output'``.

    NOTE(review): the creation order of the modules below fixes the
    state_dict keys and the RNG draws used for weight initialization, so it
    should not be reordered.
    """
    def __init__(self, image_size, dims, channels, expansion=4, kernel_size=3, patch_size=(2, 2)):
        super().__init__()
        ih, iw = image_size
        ph, pw = patch_size
        # The input resolution must split evenly into patches.
        assert ih % ph == 0 and iw % pw == 0

        L = [1, 1, 1]  # transformer depth of each MobileViT block

        self.conv1 = conv_nxn_bn(3, channels[0], stride=2)  # stem, 1/2 resolution

        self.mv2 = nn.ModuleList([])
        self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
        self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))  # 1/4
        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
        # NOTE(review): input channels are channels[2] although it consumes the
        # previous block's channels[3] output; equal (24) in the XXS config.
        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
        self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))  # 1/8
        self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))  # 1/16
        self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))  # 1/32

        # MobileViT blocks; the MLP hidden width is 2x dims[0] for the first, 4x otherwise.
        self.mvit = nn.ModuleList([])
        self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2)))
        self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4)))
        self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4)))

        # Final 1x1 projection: channels[-2] -> channels[-1].
        self.conv2 = conv_1x1_bn(channels[-2], channels[-1])

    def forward(self, x):
        econv1 = self.conv1(x)
        mv0 = self.mv2[0](econv1)

        mv1 = self.mv2[1](mv0)
        mv2 = self.mv2[2](mv1)
        mv3 = self.mv2[3](mv2)

        mv4 = self.mv2[4](mv3)
        meter0 = self.mvit[0](mv4)

        mv5 = self.mv2[5](meter0)
        meter1 = self.mvit[1](mv5)

        mv6 = self.mv2[6](meter1)
        meter2 = self.mvit[2](mv6)
        x = self.conv2(meter2)
        # List index i corresponds to layer 'enc_i' (0 = stem ... 10 = meter2).
        return x, [econv1, mv0, mv1, mv2, mv3, mv4, meter0, mv5, meter1, mv6, meter2]

def mobilevit_xxs():
    """Build the XXS MobileViT encoder for the configured input resolution."""
    height, width = RGB_img_res[1], RGB_img_res[2]
    return MobileViT(
        (height, width),
        dims=[64, 80, 96],
        channels=[16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 160],
        expansion=2,
    )

class UpSample_layer(nn.Module):
    """Decoder stage: 2x transposed conv, skip concatenation, separable conv + ReLU.

    Args:
        inp, oup: input/output channels of the transposed convolution.
        flag, name: stored as attributes only; not used in ``forward``.
        sep_conv_filters: channels after concatenation (``oup`` + skip channels).
        device: forwarded to `SeparableConv2d`.
    """

    def __init__(self, inp, oup, flag, sep_conv_filters, name, device):
        super(UpSample_layer, self).__init__()
        self.flag = flag
        self.name = name
        self.conv2d_transpose = nn.ConvTranspose2d(inp, oup, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1),
                                                   dilation=1, output_padding=(1, 1), bias=False)
        self.end_up_layer = nn.Sequential(
            SeparableConv2d(sep_conv_filters, oup, kernel_size=(3, 3), device=device),
            nn.ReLU()
        )

    def forward(self, x, enc_layer):
        upsampled = self.conv2d_transpose(x)
        # On a width mismatch, zero-pad the skip tensor by one column on the
        # left; if still mismatched, by one more column on the right.
        for pad in ((1, 0), (0, 1)):
            if upsampled.shape[-1] != enc_layer.shape[-1]:
                enc_layer = F.pad(enc_layer, pad=pad, mode='constant', value=0.0)
        merged = torch.cat([upsampled, enc_layer], dim=1)
        return self.end_up_layer(merged)

class decoder(nn.Module):
    """METER decoder: 1x1 conv, three skip-connected up-sampling stages, 3x3 output conv.

    ``forward`` returns the 1-channel depth map and the intermediate
    activations that `get_features` exposes as ``dec_0`` ... ``dec_3``.
    """

    def __init__(self, device):
        super(decoder, self).__init__()
        self.conv2d_in = nn.Conv2d(160, 64, kernel_size=(1, 1), padding='same', bias=False)
        self.ups_block_1 = UpSample_layer(64, 32, flag=True, sep_conv_filters=96, name='up1', device=device)
        self.ups_block_2 = UpSample_layer(32, 16, flag=False, sep_conv_filters=64, name='up2', device=device)
        self.ups_block_3 = UpSample_layer(16, 8, flag=False, sep_conv_filters=32, name='up3', device=device)
        self.conv2d_out = nn.Conv2d(8, 1, kernel_size=(3, 3), padding='same', bias=False)

    def forward(self, x, enc_layer_list):
        # Skips from the encoder's feature list (see MobileViT.forward):
        # index 7 = mv5, 5 = mv4, 2 = mv1.
        skip_1, skip_2, skip_3 = enc_layer_list[7], enc_layer_list[5], enc_layer_list[2]
        dconv0 = self.conv2d_in(x)
        up1 = self.ups_block_1(dconv0, skip_1)
        up2 = self.ups_block_2(up1, skip_2)
        up3 = self.ups_block_3(up2, skip_3)
        depth = self.conv2d_out(up3)
        return depth, [dconv0, up1, up2, up3]

In [None]:
class build_METER_model(nn.Module):
    """Full METER network: MobileViT-XXS encoder followed by the METER decoder.

    ``forward`` returns only the predicted depth map; use `get_features` to
    obtain intermediate activations.
    """

    def __init__(self, device):
        super(build_METER_model, self).__init__()
        self.encoder = mobilevit_xxs()
        self.decoder = decoder(device=device)

    def forward(self, x):
        bottleneck, skips = self.encoder(x)
        depth, _ = self.decoder(bottleneck, skips)
        return depth

## Train
### Baseline loss function

In [None]:
def gaussian(window_size, sigma):
    """1-D Gaussian of length ``window_size`` centred at ``window_size // 2``, summing to 1."""
    center = window_size // 2
    denom = float(2 * sigma ** 2)
    weights = torch.Tensor([np.exp(-(i - center) ** 2 / denom) for i in range(window_size)])
    return weights / weights.sum()


def create_window(window_size, channel=1):
    """Return a (channel, 1, window_size, window_size) 2-D Gaussian kernel (sigma 1.5).

    The kernel is the outer product of two normalized 1-D Gaussians, so it
    sums to 1. It is replicated per channel for grouped convolution.
    """
    column = gaussian(window_size, 1.5).unsqueeze(1)
    kernel_2d = column.mm(column.t()).float()[None, None]
    return kernel_2d.expand(channel, 1, window_size, window_size).contiguous()


def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
    """Structural similarity (SSIM) between two (N, C, H, W) image batches.

    Args:
        window_size: Gaussian window size, clipped to the image size; only
            used when ``window`` is not given.
        window: optional precomputed (C, 1, k, k) filter.
        size_average: if True, return one scalar; otherwise one value per image.
        full: if True, also return the mean contrast-sensitivity term.
        val_range: dynamic range L of the pixel values. When None it is
            inferred from ``img1``: the upper bound is 255 if max > 128,
            else 1; the lower bound is -1 if min < -0.5, else 0.

    Returns:
        SSIM, or ``(SSIM, cs)`` when ``full`` is True.
    """
    if val_range is not None:
        dynamic_range = val_range
    else:
        upper = 255 if torch.max(img1) > 128 else 1
        lower = -1 if torch.min(img1) < -0.5 else 0
        dynamic_range = upper - lower

    _, channel, height, width = img1.size()
    if window is None:
        window = create_window(min(window_size, height, width), channel=channel).to(img1.device)

    def blur(t):
        # 'Valid' (unpadded) per-channel filtering with the window.
        return F.conv2d(t, window, padding=0, groups=channel)

    mu1 = blur(img1)
    mu2 = blur(img2)
    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = blur(img1 * img1) - mu1_sq
    sigma2_sq = blur(img2 * img2) - mu2_sq
    sigma12 = blur(img1 * img2) - mu1_mu2

    C1 = (0.01 * dynamic_range) ** 2
    C2 = (0.03 * dynamic_range) ** 2

    contrast_num = 2.0 * sigma12 + C2
    contrast_den = sigma1_sq + sigma2_sq + C2
    cs = torch.mean(contrast_num / contrast_den)  # contrast sensitivity

    ssim_map = ((2 * mu1_mu2 + C1) * contrast_num) / ((mu1_sq + mu2_sq + C1) * contrast_den)
    ret = ssim_map.mean() if size_average else ssim_map.mean(1).mean(1).mean(1)
    return (ret, cs) if full else ret


class Sobel(nn.Module):
    """Frozen Sobel filter: (N, 1, H, W) -> (N, 2, H, W) of (x, y) edge responses."""

    def __init__(self):
        super(Sobel, self).__init__()
        self.edge_conv = nn.Conv2d(1, 2, kernel_size=3, stride=1, padding=1, bias=False)
        kernel_x = np.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]])
        kernel_y = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]])
        kernels = torch.from_numpy(np.stack((kernel_x, kernel_y))).float().view(2, 1, 3, 3)
        self.edge_conv.weight = nn.Parameter(kernels)
        # Fixed filter: exclude every parameter from training.
        self.requires_grad_(False)

    def forward(self, x):
        edges = self.edge_conv(x)
        return edges.contiguous().view(-1, 2, x.size(2), x.size(3))

In [None]:
class balanced_loss_function(nn.Module):
    """METER balanced loss, computed from the terms below (each term averaged):

        |pred - gt|
        + (|dx_pred - dx_gt| + |dy_pred - dy_gt|) / lambda_1
        + lambda_2 * |1 - cos(normal_pred, normal_gt)|
        + lambda_3 * (1 - SSIM(pred, gt, L=1000))

    with lambda_1 = 0.5, lambda_2 = 100, lambda_3 = 100. Gradients come from a
    frozen Sobel filter. Normals are (-dx, -dy, 1).
    """

    def __init__(self, device):
        super(balanced_loss_function, self).__init__()
        self.cos = nn.CosineSimilarity(dim=1, eps=0)
        self.get_gradient = Sobel().to(device)
        self.device = device
        self.lambda_1 = 0.5
        self.lambda_2 = 100
        self.lambda_3 = 100

    def forward(self, output, depth):
        ones = torch.ones(depth.size(0), 1, depth.size(2), depth.size(3),
                          dtype=torch.float32, device=self.device)

        grad_gt = self.get_gradient(depth)
        grad_pred = self.get_gradient(output)

        gt_dx = grad_gt[:, 0, :, :].contiguous().view_as(depth)
        gt_dy = grad_gt[:, 1, :, :].contiguous().view_as(depth)
        pred_dx = grad_pred[:, 0, :, :].contiguous().view_as(depth)
        pred_dy = grad_pred[:, 1, :, :].contiguous().view_as(depth)

        normal_gt = torch.cat((-gt_dx, -gt_dy, ones), 1)
        normal_pred = torch.cat((-pred_dx, -pred_dy, ones), 1)

        loss_depth = torch.abs(output - depth).mean()
        loss_grad = (torch.abs(pred_dx - gt_dx).mean() + torch.abs(pred_dy - gt_dy).mean()) / self.lambda_1
        loss_normal = self.lambda_2 * torch.abs(1 - self.cos(normal_pred, normal_gt)).mean()
        loss_ssim = (1 - ssim(output, depth, val_range=1000.0)) * self.lambda_3

        return loss_depth + loss_ssim + loss_normal + loss_grad

### Neuron selectivity regularizer

In [None]:
def get_loss_with_regularizer(layers, model, imgs, depths, sid_thresholds_list, device, criterion):
    """Compute the depth loss and the neuron depth-selectivity regularizer.

    Args:
        layers: Names of the layers whose feature maps are regularised;
            ``layers[i]`` is paired with ``sid_thresholds_list[i]``.
        model: Network whose features are extracted via ``get_features``.
        imgs: Input image batch.
        depths: Ground-truth depth batch. It is used by ``criterion`` and to
            assign pixels to depth bins.
        sid_thresholds_list: Per-layer bin edges; a layer with edges
            ``thresholds`` has ``len(thresholds) - 1`` depth bins.
        device: Device on which the per-bin accumulators are allocated.
        criterion: Depth loss, called as ``criterion(preds, depths)``.

    Returns:
        ``(baseline_loss, selectivity)``. ``selectivity`` is the sum over
        ``layers`` of each layer's mean per-unit selectivity.

    Raises:
        ValueError: if a layer has fewer units (channels) than depth bins,
            which would leave some bins without any assigned unit.
    """
    # With also_get_output=True the last entry is the model prediction and the
    # preceding entries are the requested layers' feature maps.
    features_list = get_features(model, imgs, layers, also_get_output=True, is_training=True)
    preds = features_list[-1]
    baseline_loss = criterion(preds, depths)

    # ---- selectivity regularizer ----
    selec_list = []
    for layer_i in range(len(layers)):
        thresholds = sid_thresholds_list[layer_i]
        n_bins = len(thresholds) - 1
        fmaps = torch.abs(features_list[layer_i])

        discret_depths = discretize_depths(depths, thresholds=thresholds)
        concept_cnt = count_discrete_depth(
            discret_depths, torch.zeros(n_bins, dtype=torch.double, device=device)
        )
        concept_sum = add_fmaps_on_discrete_depths(
            discret_depths,
            fmaps,
            torch.zeros((fmaps.shape[1], n_bins), dtype=torch.double, device=device),
        )
        # Per-unit, per-bin average |activation|; the epsilon avoids 0/0 for
        # bins that do not occur in this batch.
        concept_avg = concept_sum / (concept_cnt + 1e-15)

        num_units, num_bins = concept_avg.shape[0], concept_avg.shape[1]
        if num_units < num_bins:
            raise ValueError(
                f"layer {layers[layer_i]!r} has {num_units} units but {num_bins} depth bins; "
                "need at least one unit per bin"
            )
        # Assign units to bins in contiguous, equally sized groups. Integer
        # arithmetic gives floor(i * num_bins / num_units) exactly; the former
        # float division i // (num_units / num_bins) could round down wrongly
        # when the ratio is not exactly representable.
        unit_class = torch.arange(num_units, device=concept_avg.device) * num_bins // num_units
        layer_selec = compute_selectivity(concept_avg, unit_class)
        # If d_k is absent in a batch, the units assigned to bin k get zero
        # selectivity. They still count in the mean below.
        # Assumes concept_cnt is 1-D with one count per bin (as allocated above).
        concept_cnt_iszero = (concept_cnt == 0)[unit_class]
        layer_selec[concept_cnt_iszero] = 0
        selec_list.append(layer_selec.mean())
    selectivity = sum(selec_list)
    return baseline_loss, selectivity

### Training loop

In [None]:
# Training configuration: model, balanced depth loss and AdamW optimiser.
best_val_loss = float('inf')  # lowest validation loss seen so far; drives checkpointing
model = build_METER_model(device).to(device)
criterion = balanced_loss_function(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)

# For train with neuron selectivity: one set of depth-bin thresholds per
# regularised layer, sized by the bin count returned for that layer.
num_bins_list = get_num_units_bins_and_nums_for_plt(cfg_layers)[1]
sid_thresholds_list = [
    get_sid_thresholds(num_bins_list[layer_i]) for layer_i in range(len(cfg_layers))
]

In [None]:
def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    The loss depends on the module-level ``neuron_selective`` flag:
      * True:  ``baseline + imde_alpha * selectivity``, computed by
        ``get_loss_with_regularizer`` on ``cfg_layers``. A negative
        ``imde_alpha`` makes minimising the loss push selectivity up.
      * False: ``criterion(model(images), depths)``.

    Returns:
        ``(mean_loss, mean_selectivity)``, both averaged over batches.
        ``mean_selectivity`` is 0.0 when ``neuron_selective`` is False.
    """
    model.train()
    total_loss, selectivity_sum = 0.0, 0.0
    for images, depths in tqdm(train_loader, desc='Training'):
        images = images.to(device)
        depths = depths.to(device)

        if neuron_selective:
            baseline_loss, selectivity = get_loss_with_regularizer(
                cfg_layers, model, images, depths, sid_thresholds_list, device, criterion
            )
            loss = baseline_loss + imde_alpha * selectivity
            selectivity_sum += selectivity.item()
        else:
            outputs = model(images)
            loss = criterion(outputs, depths)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(train_loader), selectivity_sum / len(train_loader)

def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    The per-batch loss mirrors ``train_one_epoch``: the regularised loss when
    the module-level ``neuron_selective`` flag is set, otherwise
    ``criterion(outputs, depths)``. RMSE is always computed from the plain
    ``model(images)`` predictions via ``compute_rmse``.

    Returns:
        ``(mean_loss, mean_rmse)`` averaged over batches.
    """
    model.eval()
    loss_sum = 0
    rmse_sum = 0

    with torch.no_grad():
        for images, depths in tqdm(val_loader, desc='Validating'):
            images, depths = images.to(device), depths.to(device)

            outputs = model(images)
            if not neuron_selective:
                batch_loss = criterion(outputs, depths)
            else:
                # NOTE(review): this runs a second forward pass through
                # get_features (with is_training=True). Confirm that flag does
                # not switch the model out of eval mode during validation.
                baseline_loss, selectivity = get_loss_with_regularizer(
                    cfg_layers, model, images, depths, sid_thresholds_list, device, criterion
                )
                batch_loss = baseline_loss + imde_alpha * selectivity

            loss_sum += batch_loss.item()
            rmse_sum += compute_rmse(outputs, depths)

    n_batches = len(val_loader)
    return loss_sum / n_batches, rmse_sum / n_batches

In [None]:
# Training loop
# Per-epoch histories; written to .npy files once all epochs have finished.
train_losses, val_losses, selectivities = [], [], []
rmse_values = []
print(f'***** Neuron selectivity: {neuron_selective} *****')
for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')

    # Train
    train_loss, selectivity = train_one_epoch(model, train_loader, criterion, optimizer, device)
    train_losses.append(train_loss)
    selectivities.append(selectivity)
    print(f'Training Loss: {train_loss:.4f}')
    if neuron_selective:
        print(f'Selectivity: {selectivity:.4f}')

    # Validate
    val_loss, rmse = validate(model, val_loader, criterion, device)
    val_losses.append(val_loss)
    rmse_values.append(rmse)
    print(f'Validation Loss: {val_loss:.4f} RMSE: {rmse:.4f}')

    # Visualize predictions on training and validation data
    visualize_predictions(model, train_loader, val_loader, device, epoch, save=True)

    # Save best model. The criterion is the validation loss, which includes the
    # imde_alpha * selectivity term when neuron_selective is True. The saved
    # checkpoint is therefore not necessarily the lowest-RMSE epoch.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'selectivity': selectivity,
            'val_loss': val_loss,
            'rmse': rmse,
        }, 'best_meter_model.pth')
        # RMSE of the checkpoint just saved (not the minimum RMSE seen so far).
        print(f'best rmse: {rmse}')
    print('-' * 50)
    clear_gpu_memory()
# Save logs
np.save('train_losses.npy', train_losses)
np.save('val_losses.npy', val_losses)
np.save('selectivities.npy', selectivities)
np.save('rmse_values.npy', rmse_values)

### Plots

In [None]:
# Load logs
# train_losses = np.load('train_losses.npy')
# val_losses = np.load('val_losses.npy')
# selectivities = np.load('selectivities.npy')
# rmse_values = np.load('rmse_values.npy')

In [None]:
# Plot train & val losses
# The x-axis is derived from the recorded history rather than num_epochs, so
# this cell also works for partial (interrupted) runs or logs reloaded from
# files produced with a different num_epochs.
plt_epochs = range(1, len(train_losses) + 1)
plt.plot(plt_epochs, train_losses, label='Training Loss')
plt.plot(plt_epochs, val_losses, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.xticks(np.arange(0, len(train_losses) + 1, 2))
plt.legend(loc='best')
plt.savefig('train_val_loss.png')
plt.show()

In [None]:
# Plot RMSE
# The x-axis follows the length of the recorded history (see the loss plot).
plt_epochs = range(1, len(rmse_values) + 1)
plt.plot(plt_epochs, rmse_values, label='RMSE')
plt.title('Root mean squared error')
plt.xlabel('Epochs')
plt.ylabel('RMSE')
plt.xticks(np.arange(0, len(rmse_values) + 1, 2))
plt.legend(loc='best')
plt.savefig('rmse.png')
plt.show()

In [None]:
# Plot selectivity
# The x-axis follows the length of the recorded history. Values are all 0
# when training ran with neuron_selective = False.
plt_epochs = range(1, len(selectivities) + 1)
plt.plot(plt_epochs, selectivities, label='Selectivity')
plt.title('Neuron depth selectivity')
plt.xlabel('Epochs')
plt.ylabel('Selectivity')
plt.xticks(np.arange(0, len(selectivities) + 1, 2))
plt.legend(loc='best')
plt.savefig('selectivity.png')
plt.show()

## Evaluation

In [None]:
# !gdown --folder https://drive.google.com/drive/folders/1uz1VgAqYP17idqKvXAKiFTJpoPXDYQ5m -O model
# !ls

In [None]:
model = build_METER_model(device).to(device)
# map_location remaps tensors saved on another device (e.g. a CUDA-trained
# checkpoint) onto the current device, so loading also works on CPU-only machines.
model.load_state_dict(torch.load('best_meter_model.pth', map_location=device)['model_state_dict'])
# If loading models from gDrive
# model.load_state_dict(torch.load('model/best_meter_model-2.pth', map_location=device)['model_state_dict'])
model.eval()

In [None]:
# Encoder layers enc_0 ... enc_10 followed by the encoder output.
layers_to_compute_selectivity = [f"enc_{idx}" for idx in range(11)] + ["enc_output"]
compute_and_plot_selectivity(
    model,
    layers=layers_to_compute_selectivity,
    data_loader=val_loader,
    viz_plot=False,
)

In [None]:
# Decoder layers dec_0 ... dec_3.
layers_to_compute_selectivity = [f"dec_{idx}" for idx in range(4)]
compute_and_plot_selectivity(
    model,
    layers=layers_to_compute_selectivity,
    data_loader=val_loader,
    viz_plot=False,
)