## Setup

In [1]:
print('Installing torchprofile...')
!pip install torchprofile 1>/dev/null
print('All required packages have been successfully installed!')

Installing torchprofile...
All required packages have been successfully installed!


In [2]:
import copy
import math
import random
import time
from collections import OrderedDict, defaultdict
from typing import Union, List

import numpy as np
import torch
from matplotlib import pyplot as plt
from torch import nn
from torch.optim import *
from torch.optim.lr_scheduler import *
from torch.utils.data import DataLoader
from torchprofile import profile_macs
from torchvision.datasets import *
from torchvision.transforms import *
from tqdm.auto import tqdm

from torchprofile import profile_macs

assert torch.cuda.is_available(), \
"The current runtime does not have CUDA support." \
"Please go to menu bar (Runtime - Change runtime type) and select GPU"

In [3]:
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7c222dfef3f0>

## Load Model

In [4]:
from transformers import DPTImageProcessor, DPTForDepthEstimation
import torch
import numpy as np
from PIL import Image
import requests

2024-05-15 14:32:38.036933: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-15 14:32:38.037049: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-15 14:32:38.163260: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [5]:
processor = DPTImageProcessor.from_pretrained("Intel/dpt-swinv2-tiny-256")
model = DPTForDepthEstimation.from_pretrained("Intel/dpt-swinv2-tiny-256").to('cuda')

preprocessor_config.json:   0%|          | 0.00/425 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/3.72k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/164M [00:00<?, ?B/s]

In [6]:
model.eval()

DPTForDepthEstimation(
  (backbone): Swinv2Backbone(
    (embeddings): Swinv2Embeddings(
      (patch_embeddings): Swinv2PatchEmbeddings(
        (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): Swinv2Encoder(
      (layers): ModuleList(
        (0): Swinv2Stage(
          (blocks): ModuleList(
            (0-1): 2 x Swinv2Layer(
              (attention): Swinv2Attention(
                (self): Swinv2SelfAttention(
                  (continuous_position_bias_mlp): Sequential(
                    (0): Linear(in_features=2, out_features=512, bias=True)
                    (1): ReLU(inplace=True)
                    (2): Linear(in_features=512, out_features=3, bias=False)
                  )
                  (query): Linear(in_features=96, out_features=96, bias=True)
                  (key): Linear(in_features=96, out_features=96, 

## Config

In [7]:
config ={
    "General":{
        "device":"cuda",
        "type":"full",
        "model": model,
        "emb_dim":768,
        "hooks":[2, 5, 8, 11],
        "read":"projection",
        "resample_dim":256,
        "optim":"adam",
        "lr_backbone":1e-5,
        "lr_scratch":3e-4,
        "loss_depth":"ssi",
        "loss_segmentation":"ce",
        "momentum":0.9,
        "epochs":3,
        "batch_size":1,
        "path_model":"models",
        "path_predicted_images":"output",
        "seed":0,
        "patch_size":16
    },
    "Dataset":{
        "paths":{
            "path_dataset":"/kaggle/input",
            "list_datasets":["inria-fod", "nyuv2-fod", "posetrack-fod"],
            "path_images":"images",
            "path_segmentations":"segmentations",
            "path_depths":"depths"
        },
        "extensions":{
            "ext_images":".jpg",
            "ext_segmentations":".png",
            "ext_depths":".jpg"
        },
        "splits":{
            "split_train":0.6,
            "split_val":0.2,
            "split_test":0.2
        },
        "transforms":{
            "resize":256,
            "p_flip":0.5,
            "p_crop":0.3,
            "p_rot":0.2
        },
        "classes":{
            "1": {
                "name": "person",
                "color": [150,5,61]
            }
        }
    },
    "wandb":{
        "enable": False,
        "username":"younesbelkada",
        "images_to_show":3,
        "im_h":540,
        "im_w":980
    }
}

Helper function

In [8]:
import os, errno
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

from glob import glob
from PIL import Image
from torchvision import transforms, utils

In [9]:
def get_total_paths(path, ext):
    return glob(os.path.join(path, '*'+ext))

In [10]:
def get_transforms(config):
    im_size = config['Dataset']['transforms']['resize']
    transform_image = transforms.Compose([
        transforms.Resize((im_size, im_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    transform_depth = transforms.Compose([
        transforms.Resize((im_size, im_size)),
        transforms.Grayscale(num_output_channels=1) ,
        transforms.ToTensor()
    ])
    transform_seg = transforms.Compose([
        transforms.Resize((im_size, im_size), interpolation=transforms.InterpolationMode.NEAREST),
        ToMask(config['Dataset']['classes']),
    ])
    return transform_image, transform_depth, transform_seg

In [11]:
def get_splitted_dataset(config, split, dataset_name, path_images, path_depths, path_segmentation):
    list_files = [os.path.basename(im) for im in path_images]
    np.random.seed(config['General']['seed'])
    np.random.shuffle(list_files)
    if split == 'train':
        selected_files = list_files[:int(len(list_files)*config['Dataset']['splits']['split_train'])]
    elif split == 'val':
        selected_files = list_files[int(len(list_files)*config['Dataset']['splits']['split_train']):int(len(list_files)*config['Dataset']['splits']['split_train'])+int(len(list_files)*config['Dataset']['splits']['split_val'])]
    else:
        selected_files = list_files[int(len(list_files)*config['Dataset']['splits']['split_train'])+int(len(list_files)*config['Dataset']['splits']['split_val']):]

    path_images = [os.path.join(config['Dataset']['paths']['path_dataset'], dataset_name, config['Dataset']['paths']['path_images'], im[:-4]+config['Dataset']['extensions']['ext_images']) for im in selected_files]
    path_depths = [os.path.join(config['Dataset']['paths']['path_dataset'], dataset_name, config['Dataset']['paths']['path_depths'], im[:-4]+config['Dataset']['extensions']['ext_depths']) for im in selected_files]
    path_segmentation = [os.path.join(config['Dataset']['paths']['path_dataset'], dataset_name, config['Dataset']['paths']['path_segmentations'], im[:-4]+config['Dataset']['extensions']['ext_segmentations']) for im in selected_files]
    return path_images, path_depths, path_segmentation

Custom augmentation

In [12]:
class ToMask(object):
    """
        Convert a 3 channel RGB image into a 1 channel segmentation mask
    """
    def __init__(self, palette_dictionnary):
        self.nb_classes = len(palette_dictionnary)
        # sort the dictionary of the classes by the sum of rgb value -> to have always background = 0
        # self.converted_dictionnary = {i: v for i, (k, v) in enumerate(sorted(palette_dictionnary.items(), key=lambda item: sum(item[1])))}
        self.palette_dictionnary = palette_dictionnary

    def __call__(self, pil_image):
        # avoid taking the alpha channel
        image_array = np.array(pil_image)[:, :, :3]
        # get only one channel for the output
        output_array = np.zeros(image_array.shape, dtype="int")[:, :, 0]

        for label in self.palette_dictionnary.keys():
            rgb_color = self.palette_dictionnary[label]['color']
            mask = (image_array == rgb_color)
            output_array[mask[:, :, 0]] = int(label)

        output_array = torch.from_numpy(output_array).unsqueeze(0).long()
        return output_array

Dataset

In [13]:
import os
import random
from glob import glob

import torch
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm
from PIL import Image
from torch.utils.data.dataloader import default_collate
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF

In [14]:
from torch.utils.data import DataLoader
from torch.utils.data import ConcatDataset

In [15]:
def show(imgs):
    fix, axs = plt.subplots(ncols=len(imgs), squeeze=False)
    for i, img in enumerate(imgs):
        img = transforms.ToPILImage()(img.to('cpu').float())
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    plt.show()

In [16]:
class AutoFocusDataset(Dataset):
    """
        Dataset class for the AutoFocus Task. Requires for each image, its depth ground-truth and
        segmentation mask
        Args:
            :- config -: json config file
            :- dataset_name -: str
            :- split -: split ['train', 'val', 'test']
    """
    def __init__(self, config, dataset_name, split=None):
        self.split = split
        self.config = config

        path_images = os.path.join(config['Dataset']['paths']['path_dataset'], dataset_name, config['Dataset']['paths']['path_images'])
        path_depths = os.path.join(config['Dataset']['paths']['path_dataset'], dataset_name, config['Dataset']['paths']['path_depths'])
        path_segmentations = os.path.join(config['Dataset']['paths']['path_dataset'], dataset_name, config['Dataset']['paths']['path_segmentations'])
        
        self.paths_images = get_total_paths(path_images, config['Dataset']['extensions']['ext_images'])
        self.paths_depths = get_total_paths(path_depths, config['Dataset']['extensions']['ext_depths'])
        self.paths_segmentations = get_total_paths(path_segmentations, config['Dataset']['extensions']['ext_segmentations'])
        
        assert (self.split in ['train', 'test', 'val']), "Invalid split!"
        assert (len(self.paths_images) == len(self.paths_depths)), "Different number of instances between the input and the depth maps"
        assert (len(self.paths_images) == len(self.paths_segmentations)), "Different number of instances between the input and the segmentation maps"
        assert (config['Dataset']['splits']['split_train']+config['Dataset']['splits']['split_test']+config['Dataset']['splits']['split_val'] == 1), "Invalid splits (sum must be equal to 1)"
        # check for segmentation

      
        # utility func for splitting
        self.paths_images, self.paths_depths, self.paths_segmentations = get_splitted_dataset(config, self.split, dataset_name, self.paths_images, self.paths_depths, self.paths_segmentations)

        #----------------------------------------------------------------------------------------------------------------------------
        # Get the transforms
        self.transform_image, self.transform_depth, self.transform_seg = get_transforms(config)

        # get p_flip from config
        self.p_flip = config['Dataset']['transforms']['p_flip'] if split=='train' else 0
        self.p_crop = config['Dataset']['transforms']['p_crop'] if split=='train' else 0
        self.p_rot = config['Dataset']['transforms']['p_rot'] if split=='train' else 0
        self.resize = config['Dataset']['transforms']['resize']
        
    def __len__(self):
        """
            Function to get the number of images using the given list of images
        """
        return len(self.paths_images)
    
    def __getitem__(self, idx):
        """
            Getter function in order to get the triplet of images / depth maps and segmentation masks
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = self.transform_image(Image.open(self.paths_images[idx]))
        depth = self.transform_depth(Image.open(self.paths_depths[idx]))
        segmentation = self.transform_seg(Image.open(self.paths_segmentations[idx]))
        imgorig = image.clone()

        if random.random() < self.p_flip:
            image = TF.hflip(image)
            depth = TF.hflip(depth)
            segmentation = TF.hflip(segmentation)

        if random.random() < self.p_crop:
            random_size = random.randint(128, self.resize-1)
            max_size = self.resize - random_size
            left = int(random.random()*max_size)
            top = int(random.random()*max_size)
            image = TF.crop(image, top, left, random_size, random_size)
            depth = TF.crop(depth, top, left, random_size, random_size)
            segmentation = TF.crop(segmentation, top, left, random_size, random_size)
            image = transforms.Resize((self.resize, self.resize))(image)
            depth = transforms.Resize((self.resize, self.resize))(depth)
            segmentation = transforms.Resize((self.resize, self.resize), interpolation=transforms.InterpolationMode.NEAREST)(segmentation)

        if random.random() < self.p_rot:
            #rotate
            random_angle = random.random()*20 - 10 #[-10 ; 10]
            mask = torch.ones((1,self.resize,self.resize)) #useful for the resize at the end
            mask = TF.rotate(mask, random_angle, interpolation=transforms.InterpolationMode.BILINEAR)
            image = TF.rotate(image, random_angle, interpolation=transforms.InterpolationMode.BILINEAR)
            depth = TF.rotate(depth, random_angle, interpolation=transforms.InterpolationMode.BILINEAR)
            segmentation = TF.rotate(segmentation, random_angle, interpolation=transforms.InterpolationMode.NEAREST)
            #crop to remove black borders due to the rotation
            left = torch.argmax(mask[:,0,:]).item()
            top = torch.argmax(mask[:,:,0]).item()
            coin = min(left,top)
            size = self.resize - 2*coin
            image = TF.crop(image, coin, coin, size, size)
            depth = TF.crop(depth, coin, coin, size, size)
            segmentation = TF.crop(segmentation, coin, coin, size, size)
            #Resize
            image = transforms.Resize((self.resize, self.resize))(image)
            depth = transforms.Resize((self.resize, self.resize))(depth)
            segmentation = transforms.Resize((self.resize, self.resize), interpolation=transforms.InterpolationMode.NEAREST)(segmentation)
        # show([imgorig, image, depth, segmentation])
        # exit(0)
        return image, depth, segmentation

In [17]:
list_data = config['Dataset']['paths']['list_datasets']
list_data

['inria-fod', 'nyuv2-fod', 'posetrack-fod']

In [18]:
autofocus_datasets_train = [] #include 3 dataset 
for dataset_name in list_data:
     autofocus_datasets_train.append(AutoFocusDataset(config, dataset_name, 'train'))
train_data = ConcatDataset(autofocus_datasets_train)
train_dataloader = DataLoader(train_data, batch_size=config['General']['batch_size'], shuffle=True)
len(train_data)

1512

In [19]:
autofocus_datasets_test = [] #include 3 dataset 
for dataset_name in list_data:
     autofocus_datasets_test.append(AutoFocusDataset(config, dataset_name, 'test'))
test_data = ConcatDataset(autofocus_datasets_train)
test_dataloader = DataLoader(test_data, batch_size=config['General']['batch_size'], shuffle=True)
len(test_data)

1512

In [20]:
autofocus_datasets_val = []
for dataset_name in list_data:
    autofocus_datasets_val.append(AutoFocusDataset(config, dataset_name, 'val'))
val_data = ConcatDataset(autofocus_datasets_val)
val_dataloader = DataLoader(val_data, batch_size=config['General']['batch_size'], shuffle=True)
len(val_data)

504

Loss function

In [21]:
import torch
import torch.nn as nn

In [22]:
def compute_scale_and_shift(prediction, target, mask):
    # system matrix: A = [[a_00, a_01], [a_10, a_11]]
    a_00 = torch.sum(mask * prediction * prediction, (1, 2))
    a_01 = torch.sum(mask * prediction, (1, 2))
    a_11 = torch.sum(mask, (1, 2))

    # right hand side: b = [b_0, b_1]
    b_0 = torch.sum(mask * prediction * target, (1, 2))
    b_1 = torch.sum(mask * target, (1, 2))

    # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b
    x_0 = torch.zeros_like(b_0)
    x_1 = torch.zeros_like(b_1)

    det = a_00 * a_11 - a_01 * a_01
    valid = det.nonzero()

    x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid]
    x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid]

    return x_0, x_1

In [23]:
def reduction_batch_based(image_loss, M):
    # average of all valid pixels of the batch

    # avoid division by 0 (if sum(M) = sum(sum(mask)) = 0: sum(image_loss) = 0)
    divisor = torch.sum(M)

    if divisor == 0:
        return 0
    else:
        return torch.sum(image_loss) / divisor

In [24]:
def reduction_image_based(image_loss, M):
    # mean of average of valid pixels of an image

    # avoid division by 0 (if M = sum(mask) = 0: image_loss = 0)
    valid = M.nonzero()

    image_loss[valid] = image_loss[valid] / M[valid]

    return torch.mean(image_loss)

In [25]:
def mse_loss(prediction, target, mask, reduction=reduction_batch_based):

    M = torch.sum(mask, (1, 2))
    res = prediction - target
    image_loss = torch.sum(mask * res * res, (1, 2))

    return reduction(image_loss, 2 * M)

In [26]:
def gradient_loss(prediction, target, mask, reduction=reduction_batch_based):

    M = torch.sum(mask, (1, 2))

    diff = prediction - target
    diff = torch.mul(mask, diff)

    grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1])
    mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1])
    grad_x = torch.mul(mask_x, grad_x)

    grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :])
    mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :])
    grad_y = torch.mul(mask_y, grad_y)

    image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2))

    return reduction(image_loss, M)

In [27]:
class ScaleAndShiftInvariantLoss(nn.Module):
    def __init__(self, alpha=0.5, scales=4, reduction='batch-based'):
        super().__init__()

        self.__data_loss = MSELoss(reduction=reduction)
        self.__regularization_loss = GradientLoss(scales=scales, reduction=reduction)
        self.__alpha = alpha

        self.__prediction_ssi = None

    def forward(self, prediction, target):
        #preprocessing
        mask = target > 0

        #calcul
        scale, shift = compute_scale_and_shift(prediction, target, mask)
        # print(scale, shift)
        self.__prediction_ssi = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1)

        total = self.__data_loss(self.__prediction_ssi, target, mask)
        if self.__alpha > 0:
            total += self.__alpha * self.__regularization_loss(self.__prediction_ssi, target, mask)

        return total

    def __get_prediction_ssi(self):
        return self.__prediction_ssi

    prediction_ssi = property(__get_prediction_ssi)

In [28]:
class MSELoss(nn.Module):
    def __init__(self, reduction='batch-based'):
        super().__init__()

        if reduction == 'batch-based':
            self.__reduction = reduction_batch_based
        else:
            self.__reduction = reduction_image_based

    def forward(self, prediction, target, mask):
        return mse_loss(prediction, target, mask, reduction=self.__reduction)

In [29]:
class GradientLoss(nn.Module):
    def __init__(self, scales=4, reduction='batch-based'):
        super().__init__()

        if reduction == 'batch-based':
            self.__reduction = reduction_batch_based
        else:
            self.__reduction = reduction_image_based

        self.__scales = scales

    def forward(self, prediction, target, mask):
        total = 0

        for scale in range(self.__scales):
            step = pow(2, scale)

            total += gradient_loss(prediction[:, ::step, ::step], target[:, ::step, ::step],
                                   mask[:, ::step, ::step], reduction=self.__reduction)

        return total

In [30]:
import os, errno
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

from glob import glob
from PIL import Image
from torchvision import transforms, utils

In [31]:
def get_losses(config):
    def NoneFunction(a, b):
        return 0
    loss_depth = NoneFunction
    loss_segmentation = NoneFunction
    type = config['General']['type']
    if type == "full" or type=="depth":
        if config['General']['loss_depth'] == 'mse':
            loss_depth = nn.MSELoss()
        elif config['General']['loss_depth'] == 'ssi':
            loss_depth = ScaleAndShiftInvariantLoss()
    if type == "full" or type=="segmentation":
        if config['General']['loss_segmentation'] == 'ce':
            loss_segmentation = nn.CrossEntropyLoss()
    return loss_depth, loss_segmentation

In [32]:
def get_optimizer(config, net):
    names = set([name.split('.')[0] for name, _ in net.named_modules()]) - set(['', 'backbone'])
    params_backbone = net.backbone.parameters()
    params_scratch = list()
    for name in names:
        params_scratch += list(eval("net."+name).parameters())

    if config['General']['optim'] == 'adam':
        optimizer_backbone = optim.Adam(params_backbone, lr=config['General']['lr_backbone'])
        optimizer_scratch = optim.Adam(params_scratch, lr=config['General']['lr_scratch'])
    elif config['General']['optim'] == 'sgd':
        optimizer_backbone = optim.SGD(params_backbone, lr=config['General']['lr_backbone'], momentum=config['General']['momentum'])
        optimizer_scratch = optim.SGD(params_scratch, lr=config['General']['lr_scratch'], momentum=config['General']['momentum'])
    return optimizer_backbone, optimizer_scratch

In [33]:
def get_schedulers(optimizers):
    return [ReduceLROnPlateau(optimizer) for optimizer in optimizers]

In [34]:
def create_dir(directory):
    try:
        os.makedirs(directory)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise

trainer

In [35]:
import os
import torch
import matplotlib.pyplot as plt
import numpy as np
import wandb
import cv2
import torch.nn as nn

from tqdm import tqdm
from os import replace
from numpy.core.numeric import Inf

In [77]:
class Trainer(object):
    def __init__(self, config, pruner=None):
        super().__init__()
        self.config = config
        self.type = self.config['General']['type']
        self.pruner = pruner

        self.device = torch.device(self.config['General']['device'] if torch.cuda.is_available() else "cpu")
        print("device: %s" % self.device)
        resize = config['Dataset']['transforms']['resize']
        
        self.model = config['General']['model']
        
        self.loss_depth, self.loss_segmentation = get_losses(config)
        self.optimizer_backbone, self.optimizer_scratch = get_optimizer(config, self.model)
        self.schedulers = get_schedulers([self.optimizer_backbone, self.optimizer_scratch])
        
    def train(self, train_dataloader, val_dataloader):
        epochs = self.config['General']['epochs']
        if self.config['wandb']['enable']:
            wandb.init(project="FocusOnDepth", entity=self.config['wandb']['username'])
            wandb.config = {
                "learning_rate_backbone": self.config['General']['lr_backbone'],
                "learning_rate_scratch": self.config['General']['lr_scratch'],
                "epochs": epochs,
                "batch_size": self.config['General']['batch_size']
            }
        val_loss = Inf
        for epoch in range(epochs):  # loop over the dataset multiple times
            print("Epoch ", epoch+1)
            running_loss = 0.0
            self.model.train()
            pbar = tqdm(train_dataloader)
            pbar.set_description("Training")
            for i, (X, Y_depths, Y_segmentations) in enumerate(pbar):
                X, Y_depths, Y_segmentations = X.to(self.device), Y_depths.to(self.device), Y_segmentations.to(self.device)
                self.optimizer_backbone.zero_grad()
                self.optimizer_scratch.zero_grad()
                output_depths = self.model(X)
    
                output_depths = output_depths['predicted_depth'].squeeze(1) if output_depths != None else None

                Y_depths = Y_depths.squeeze(1) #1xHxW -> HxW
                loss = self.loss_depth(output_depths, Y_depths)
                loss.backward()
                # step optimizer
                self.optimizer_scratch.step()
                self.optimizer_backbone.step()
                
                # Apply pruner to keep the model sparse
                if self.pruner != None:
                    self.pruner.apply(self.model)

                running_loss += loss.item()
                if np.isnan(running_loss):
                    print('\n',
                        X.min().item(), X.max().item(),'\n',
                        Y_depths.min().item(), Y_depths.max().item(),'\n',
                        output_depths.min().item(), output_depths.max().item(),'\n',
                        loss.item(),
                    )
                    exit(0)
                

                if self.config['wandb']['enable'] and ((i % 50 == 0 and i>0) or i==len(train_dataloader)-1):
                    wandb.log({"loss": running_loss/(i+1)})
                pbar.set_postfix({'training_loss': running_loss/(i+1)})

            new_val_loss = self.run_eval(val_dataloader)

            if new_val_loss < val_loss:
                self.save_model()
                val_loss = new_val_loss

            self.schedulers[0].step(new_val_loss)
            self.schedulers[1].step(new_val_loss)

        print('Finished Training')
        
    def run_eval(self, val_dataloader):
        """
            Evaluate the model on the validation set and visualize some results
            on wandb
            :- val_dataloader -: torch dataloader
        """
        val_loss = 0.
        self.model.eval()
        X_1 = None
        Y_depths_1 = None
        Y_segmentations_1 = None
        output_depths_1 = None
        output_segmentations_1 = None
        with torch.no_grad():
            pbar = tqdm(val_dataloader)
            pbar.set_description("Validation")
            for i, (X, Y_depths, Y_segmentations) in enumerate(pbar):
                X, Y_depths, Y_segmentations = X.to(self.device), Y_depths.to(self.device), Y_segmentations.to(self.device)
                output_depths = self.model(X)
                output_depths = output_depths['predicted_depth'].squeeze(1) if output_depths != None else None
                Y_depths = Y_depths.squeeze(1)
                if i==0:
                    X_1 = X
                    Y_depths_1 = Y_depths
                    output_depths_1 = output_depths

                loss = self.loss_depth(output_depths, Y_depths)
                val_loss += loss.item()
                pbar.set_postfix({'validation_loss': val_loss/(i+1)})
            if self.config['wandb']['enable']:
                wandb.log({"val_loss": val_loss/(i+1)})
                self.img_logger(X_1, Y_depths_1, Y_segmentations_1, output_depths_1, output_segmentations_1)
        return val_loss/(i+1)
    
    def save_model(self):
        path_model = os.path.join(self.config['General']['path_model'], self.model.__class__.__name__)
        create_dir(path_model)
        torch.save(self.model.state_dict(), path_model+'.pt')
        print('Model saved at : {}'.format(path_model))

    
    def img_logger(self, X, Y_depths, Y_segmentations, output_depths, output_segmentations):
        nb_to_show = self.config['wandb']['images_to_show'] if self.config['wandb']['images_to_show'] <= len(X) else len(X)
        tmp = X[:nb_to_show].detach().cpu().numpy()
        imgs = (tmp - tmp.min()) / (tmp.max() - tmp.min())
        if output_depths != None:
            tmp = Y_depths[:nb_to_show].unsqueeze(1).detach().cpu().numpy()
            depth_truths = np.repeat(tmp, 3, axis=1)
            tmp = output_depths[:nb_to_show].unsqueeze(1).detach().cpu().numpy()
            tmp = np.repeat(tmp, 3, axis=1)
            #depth_preds = 1.0 - tmp
            depth_preds = tmp
        if output_segmentations != None:
            tmp = Y_segmentations[:nb_to_show].unsqueeze(1).detach().cpu().numpy()
            segmentation_truths = np.repeat(tmp, 3, axis=1).astype('float32')
            tmp = torch.argmax(output_segmentations[:nb_to_show], dim=1)
            tmp = tmp.unsqueeze(1).detach().cpu().numpy()
            tmp = np.repeat(tmp, 3, axis=1)
            segmentation_preds = tmp.astype('float32')
        imgs = imgs.transpose(0,2,3,1)
        if output_depths != None:
            depth_truths = depth_truths.transpose(0,2,3,1)
            depth_preds = depth_preds.transpose(0,2,3,1)
        if output_segmentations != None:
            segmentation_truths = segmentation_truths.transpose(0,2,3,1)
            segmentation_preds = segmentation_preds.transpose(0,2,3,1)
        output_dim = (int(self.config['wandb']['im_w']), int(self.config['wandb']['im_h']))

        wandb.log({
            "img": [wandb.Image(cv2.resize(im, output_dim), caption='img_{}'.format(i+1)) for i, im in enumerate(imgs)]
        })
        if output_depths != None:
            wandb.log({
                "depth_truths": [wandb.Image(cv2.resize(im, output_dim), caption='depth_truths_{}'.format(i+1)) for i, im in enumerate(depth_truths)],
                "depth_preds": [wandb.Image(cv2.resize(im, output_dim), caption='depth_preds_{}'.format(i+1)) for i, im in enumerate(depth_preds)]
            })
        if output_segmentations != None:
            wandb.log({
                "seg_truths": [wandb.Image(cv2.resize(im, output_dim), caption='seg_truths_{}'.format(i+1)) for i, im in enumerate(segmentation_truths)],
                "seg_preds": [wandb.Image(cv2.resize(im, output_dim), caption='seg_preds_{}'.format(i+1)) for i, im in enumerate(segmentation_preds)]
            })

In [37]:
def eval_fn(model, val_dataloader):
        """
            Evaluate the model on the validation set and visualize some results
            on wandb
            :- model -: torch model using for evaluate
            :- val_dataloader -: torch dataloader
        """
        loss_depth, loss_segmentation = get_losses(config)
        device = torch.device(config['General']['device'] if torch.cuda.is_available() else "cpu")
        
        val_loss = 0.
        model.eval()
        
        #
        X_1 = None 
        Y_depths_1 = None
        output_depths_1 = None
        
        #
        with torch.no_grad():
            for i, (X, Y_depths, _) in enumerate(val_dataloader):
                X, Y_depths = X.to(device), Y_depths.to(device)
                output_depths = model(X)
                output_depths = output_depths['predicted_depth'].squeeze(1) if output_depths != None else None
                Y_depths = Y_depths.squeeze(1)
                if i==0:
                    X_1 = X
                    Y_depths_1 = Y_depths
                    output_depths_1 = output_depths
                    
                # get loss
                loss = loss_depth(output_depths, Y_depths)
                val_loss += loss.item()
#                 pbar.set_postfix({'validation_loss': val_loss/(i+1)})
        return val_loss/(i+1)

In [38]:
model_path = "/kaggle/input/weight-dpt-swin2-tiny-256-ssiloss/models/DPTForDepthEstimation-ssi.pt"

checkpoint = torch.load(model_path)

In [39]:
model.load_state_dict(checkpoint)

<All keys matched successfully>

In [40]:
model.eval()

DPTForDepthEstimation(
  (backbone): Swinv2Backbone(
    (embeddings): Swinv2Embeddings(
      (patch_embeddings): Swinv2PatchEmbeddings(
        (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): Swinv2Encoder(
      (layers): ModuleList(
        (0): Swinv2Stage(
          (blocks): ModuleList(
            (0-1): 2 x Swinv2Layer(
              (attention): Swinv2Attention(
                (self): Swinv2SelfAttention(
                  (continuous_position_bias_mlp): Sequential(
                    (0): Linear(in_features=2, out_features=512, bias=True)
                    (1): ReLU(inplace=True)
                    (2): Linear(in_features=512, out_features=3, bias=False)
                  )
                  (query): Linear(in_features=96, out_features=96, bias=True)
                  (key): Linear(in_features=96, out_features=96, 

In [41]:
recover_model = lambda : model.load_state_dict(checkpoint)

Helper Functions (Flops, Model Size calculation, etc.)

In [42]:
def get_model_macs(model, inputs) -> int:
    return profile_macs(model, inputs)


def get_sparsity(tensor: torch.Tensor) -> float:
    """
    calculate the sparsity of the given tensor
        sparsity = #zeros / #elements = 1 - #nonzeros / #elements
    """
    return 1 - float(tensor.count_nonzero()) / tensor.numel()


def get_model_sparsity(model: nn.Module) -> float:
    """
    calculate the sparsity of the given model
        sparsity = #zeros / #elements = 1 - #nonzeros / #elements
    """
    num_nonzeros, num_elements = 0, 0
    for param in model.parameters():
        num_nonzeros += param.count_nonzero()
        num_elements += param.numel()
    return 1 - float(num_nonzeros) / num_elements

def get_num_parameters(model: nn.Module, count_nonzero_only=False) -> int:
    """
    calculate the total number of parameters of model
    :param count_nonzero_only: only count nonzero weights
    """
    num_counted_elements = 0
    for param in model.parameters():
        if count_nonzero_only:
            num_counted_elements += param.count_nonzero()
        else:
            num_counted_elements += param.numel()
    return num_counted_elements


def get_model_size(model: nn.Module, data_width=32, count_nonzero_only=False) -> int:
    """
    calculate the model size in bits
    :param data_width: #bits per element
    :param count_nonzero_only: only count nonzero weights
    """
    return get_num_parameters(model, count_nonzero_only) * data_width

Byte = 8
KiB = 1024 * Byte
MiB = 1024 * KiB
GiB = 1024 * MiB

In [43]:
dense_model_accuracy = eval_fn(model, test_dataloader)
dense_model_size = get_model_size(model)
print(f"dense model has ssi loss={dense_model_accuracy:.2f}")
print(f"dense model has size={dense_model_size/MiB:.2f} MiB")



dense model has ssi loss=0.03
dense model has size=156.14 MiB


# fine_grained_prune

In [44]:
def fine_grained_prune(tensor: torch.Tensor, sparsity : float) -> torch.Tensor:
    """
    magnitude-based pruning for single tensor
    :param tensor: torch.(cuda.)Tensor, weight of conv/fc layer
    :param sparsity: float, pruning sparsity
        sparsity = #zeros / #elements = 1 - #nonzeros / #elements
    :return:
        torch.(cuda.)Tensor, mask for zeros
    """
    sparsity = min(max(0.0, sparsity), 1.0)
    if sparsity == 1.0:
        tensor.zero_()
        return torch.zeros_like(tensor)
    elif sparsity == 0.0:
        return torch.ones_like(tensor)
    
    num_elements = tensor.numel()
    #
    # Step 1: calculate the #zeros (please use round())
    num_zeros = round(num_elements * sparsity)
    
    # Step 2: calculate the importance of weight
    importance = torch.abs(tensor)
    
    # Step 3: calculate the pruning threshold
    threshold = torch.kthvalue(importance.flatten(), num_zeros)[0]
    
    # Step 4: get binary mask (1 for nonzeros, 0 for zeros)
    mask = importance > threshold
    #
    tensor.mul_(mask)

    return mask

In [45]:
class FineGrainedPruner:
    def __init__(self, model, sparsity_dict):
        self.masks = FineGrainedPruner.prune(model, sparsity_dict)

    @torch.no_grad()
    def apply(self, model):
        for name, param in model.named_parameters():
            if name in self.masks:
                param *= self.masks[name]

    @staticmethod
    @torch.no_grad()
    def prune(model, sparsity_dict):
        masks = dict()
        for name, param in model.named_parameters():
            if param.dim() > 3: # we only prune conv and fc weights
                masks[name] = fine_grained_prune(param, sparsity_dict[name])
        return masks

sparsity dict

In [46]:
sparsity_dict = {
    'backbone.embeddings.patch_embeddings.projection.weight': 0,
    'neck.convs.0.weight':0.9,
    'neck.convs.1.weight':0.9,
    'neck.convs.2.weight':0.9,
    'neck.convs.3.weight':0.8,
    'neck.fusion_stage.layers.0.projection.weight':0.8,
    'neck.fusion_stage.layers.0.residual_layer1.convolution1.weight':0.9,
    'neck.fusion_stage.layers.0.residual_layer1.convolution2.weight':0.9,
    'neck.fusion_stage.layers.0.residual_layer2.convolution1.weight':0.7,
    'neck.fusion_stage.layers.0.residual_layer2.convolution2.weight':0.9,
    'neck.fusion_stage.layers.1.projection.weight':0.7,
    'neck.fusion_stage.layers.1.residual_layer1.convolution1.weight':0.9,
    'neck.fusion_stage.layers.1.residual_layer1.convolution2.weight':0.8,
    'neck.fusion_stage.layers.1.residual_layer2.convolution1.weight':0.9,
    'neck.fusion_stage.layers.1.residual_layer2.convolution2.weight':0.9,
    'neck.fusion_stage.layers.2.projection.weight':0.9,
    'neck.fusion_stage.layers.2.residual_layer1.convolution1.weight':0.9,
    'neck.fusion_stage.layers.2.residual_layer1.convolution2.weight':0.7,
    'neck.fusion_stage.layers.2.residual_layer2.convolution1.weight':0.9,
    'neck.fusion_stage.layers.2.residual_layer2.convolution2.weight':0.9,
    'neck.fusion_stage.layers.3.projection.weight':0.9,
    'neck.fusion_stage.layers.3.residual_layer1.convolution1.weight':0.8,
    'neck.fusion_stage.layers.3.residual_layer1.convolution2.weight':0.9,
    'neck.fusion_stage.layers.3.residual_layer2.convolution1.weight':0.9,
    'neck.fusion_stage.layers.3.residual_layer2.convolution2.weight':0.6,
    'head.head.0.weight':0.7,
    'head.head.2.weight':0.7,
    'head.head.4.weight':0.6
}

In [47]:
pruner = FineGrainedPruner(model, sparsity_dict)

In [48]:
print(f'After pruning with sparsity dictionary')
for name, sparsity in sparsity_dict.items():
    print(f'  {name}: {sparsity:.2f}')
print(f'The sparsity of each layer becomes')
for name, param in model.named_parameters():
    if name in sparsity_dict:
        print(f'  {name}: {get_sparsity(param):.2f}')

sparse_model_size = get_model_size(model, count_nonzero_only=True)
print(f"Sparse model has size={sparse_model_size / MiB:.2f} MiB = {sparse_model_size / dense_model_size * 100:.2f}% of dense model size")
sparse_model_accuracy = eval_fn(model, test_dataloader)
print(f"Sparse model has ssi loss ={sparse_model_accuracy:.2f} before fintuning")

After pruning with sparsity dictionary
  backbone.embeddings.patch_embeddings.projection.weight: 0.00
  neck.convs.0.weight: 0.90
  neck.convs.1.weight: 0.90
  neck.convs.2.weight: 0.90
  neck.convs.3.weight: 0.80
  neck.fusion_stage.layers.0.projection.weight: 0.80
  neck.fusion_stage.layers.0.residual_layer1.convolution1.weight: 0.90
  neck.fusion_stage.layers.0.residual_layer1.convolution2.weight: 0.90
  neck.fusion_stage.layers.0.residual_layer2.convolution1.weight: 0.70
  neck.fusion_stage.layers.0.residual_layer2.convolution2.weight: 0.90
  neck.fusion_stage.layers.1.projection.weight: 0.70
  neck.fusion_stage.layers.1.residual_layer1.convolution1.weight: 0.90
  neck.fusion_stage.layers.1.residual_layer1.convolution2.weight: 0.80
  neck.fusion_stage.layers.1.residual_layer2.convolution1.weight: 0.90
  neck.fusion_stage.layers.1.residual_layer2.convolution2.weight: 0.90
  neck.fusion_stage.layers.2.projection.weight: 0.90
  neck.fusion_stage.layers.2.residual_layer1.convolution1.w

## Finetune the fine-grained pruned model

In [49]:
pruner.masks.keys()

dict_keys(['backbone.embeddings.patch_embeddings.projection.weight', 'neck.convs.0.weight', 'neck.convs.1.weight', 'neck.convs.2.weight', 'neck.convs.3.weight', 'neck.fusion_stage.layers.0.projection.weight', 'neck.fusion_stage.layers.0.residual_layer1.convolution1.weight', 'neck.fusion_stage.layers.0.residual_layer1.convolution2.weight', 'neck.fusion_stage.layers.0.residual_layer2.convolution1.weight', 'neck.fusion_stage.layers.0.residual_layer2.convolution2.weight', 'neck.fusion_stage.layers.1.projection.weight', 'neck.fusion_stage.layers.1.residual_layer1.convolution1.weight', 'neck.fusion_stage.layers.1.residual_layer1.convolution2.weight', 'neck.fusion_stage.layers.1.residual_layer2.convolution1.weight', 'neck.fusion_stage.layers.1.residual_layer2.convolution2.weight', 'neck.fusion_stage.layers.2.projection.weight', 'neck.fusion_stage.layers.2.residual_layer1.convolution1.weight', 'neck.fusion_stage.layers.2.residual_layer1.convolution2.weight', 'neck.fusion_stage.layers.2.residua

In [50]:
trainer = Trainer(config, pruner)
trainer.train(train_dataloader, val_dataloader)

device: cuda
Epoch  1


Training: 100%|██████████| 1512/1512 [03:18<00:00,  7.61it/s, training_loss=0.0378]
Validation: 100%|██████████| 504/504 [00:48<00:00, 10.44it/s, validation_loss=0.0334]


Model saved at : models/DPTForDepthEstimation
Epoch  2


Training: 100%|██████████| 1512/1512 [03:19<00:00,  7.57it/s, training_loss=0.0318]
Validation: 100%|██████████| 504/504 [00:37<00:00, 13.34it/s, validation_loss=0.0322]


Model saved at : models/DPTForDepthEstimation
Epoch  3


Training: 100%|██████████| 1512/1512 [03:18<00:00,  7.63it/s, training_loss=0.0305]
Validation: 100%|██████████| 504/504 [00:37<00:00, 13.58it/s, validation_loss=0.0307]


Model saved at : models/DPTForDepthEstimation
Finished Training


In [51]:
sparse_model_size = get_model_size(model, count_nonzero_only=True)
print(f"Sparse model has size={sparse_model_size / MiB:.2f} MiB = {sparse_model_size / dense_model_size * 100:.2f}% of dense model size")
sparse_model_accuracy = eval_fn(model, test_dataloader)
print(f"Sparse model has ssi loss={sparse_model_accuracy:.2f}% after fintuning")

Sparse model has size=113.34 MiB = 72.59% of dense model size
Sparse model has ssi loss=0.03% after fintuning


# Quantization

In [52]:
!pip install -qqq fast-pytorch-kmeans

In [53]:
import copy 

In [54]:
from collections import namedtuple
    
Codebook = namedtuple('Codebook', ['centroids', 'labels'])

In [55]:
from fast_pytorch_kmeans import KMeans

In [56]:
def k_means_quantize(fp32_tensor: torch.Tensor, bitwidth=4, codebook=None):
    """
    quantize tensor using k-means clustering
    :param fp32_tensor:
    :param bitwidth: [int] quantization bit width, default=4
    :param codebook: [Codebook] (the cluster centroids, the cluster label tensor)
    :return:
        [Codebook = (centroids, labels)]
            centroids: [torch.(cuda.)FloatTensor] the cluster centroids
            labels: [torch.(cuda.)LongTensor] cluster label tensor
    """
    if codebook is None:
        n_clusters = 2**bitwidth
        kmeans = KMeans(n_clusters=n_clusters, mode='euclidean', verbose=0)
        labels = kmeans.fit_predict(fp32_tensor.view(-1, 1)).to(torch.long)
        centroids = kmeans.centroids.to(torch.float).view(-1)
        codebook = Codebook(centroids, labels)
    quantized_tensor = codebook.centroids[codebook.labels]
    fp32_tensor.set_(quantized_tensor.view_as(fp32_tensor))
    return codebook

Quantize in whole model

In [57]:
class KMeansQuantizer:
    def __init__(self, model : nn.Module, bitwidth=4):
        self.codebook = KMeansQuantizer.quantize(model, bitwidth)
        
    @torch.no_grad()
    def apply(self, model, update_centroids):
        for name, param in model.named_parameters():
            if name in self.codebook:
                if update_centroids:
                    update_codebook(param, codebook=self.codebook[name])
                self.codebook[name] = k_means_quantize(
                    param, codebook=self.codebook[name])
                
    @staticmethod
    @torch.no_grad()
    def quantize(model: nn.Module, bitwidth=4):
        codebook = dict()
        if isinstance(bitwidth, dict):
            for name, param in model.named_parameters():
                if name in bitwidth:
                    codebook[name] = k_means_quantize(param, bitwidth=bitwidth[name])
        else:
            for name, param in model.named_parameters():
                # print(f"{name=}: {param.shape}")
                # only quantize weight, not bias
                if param.dim() > 1:
                    codebook[name] = k_means_quantize(param, bitwidth=bitwidth)
        return codebook

In [70]:
quantizers = dict()

for bitwidth in [8, 4, 2]:
    # Recover model
    model_copy = copy.deepcopy(model)
    print(f'k-means quantizing model into {bitwidth} bits')
    quantizer = KMeansQuantizer(model_copy, bitwidth) # codebook
    quantized_model_size = get_model_size(model_copy, bitwidth, True)
    print(f"    {bitwidth}-bit k-means quantized model has size={quantized_model_size/MiB:.2f} MiB")
    quantized_model_accuracy = eval_fn(model_copy, test_dataloader)
    print(f"    {bitwidth}-bit k-means quantized model has ssi loss={quantized_model_accuracy:.2f}")
    quantizers[bitwidth] = quantizer

k-means quantizing model into 8 bits
    8-bit k-means quantized model has size=37.86 MiB
    8-bit k-means quantized model has ssi loss=0.03
k-means quantizing model into 4 bits
    4-bit k-means quantized model has size=19.39 MiB
    4-bit k-means quantized model has ssi loss=0.08
k-means quantizing model into 2 bits
    2-bit k-means quantized model has size=9.76 MiB
    2-bit k-means quantized model has ssi loss=0.11


In [74]:
quantizers

{8: <__main__.KMeansQuantizer at 0x7c205fba1420>,
 4: <__main__.KMeansQuantizer at 0x7c20b84692d0>,
 2: <__main__.KMeansQuantizer at 0x7c206c4d4340>}

## Trained K Mean quantization

In [71]:
def update_codebook(fp32_tensor: torch.Tensor, codebook: Codebook):
    """
    update the centroids in the codebook using updated fp32_tensor
    :param fp32_tensor: [torch.(cuda.)Tensor]
    :param codebook: [Codebook] (the cluster centroids, the cluster label tensor)
    """
    n_clusters = codebook.centroids.numel()
    fp32_tensor = fp32_tensor.view(-1)
    for k in range(n_clusters):
        codebook.centroids[k] = torch.mean(fp32_tensor[codebook.labels==k])

In [72]:
fp32_model_accuracy = eval_fn(model, test_dataloader)
print(f"fp32 model has ssi loss={fp32_model_accuracy:.2f}")

fp32 model has ssi loss=0.03


In [75]:
accuracy_drop_threshold = 0.02
quantizers_before_finetune = copy.deepcopy(quantizers)
quantizers_after_finetune = quantizers

In [78]:
for bitwidth in [8, 4, 2]:
    model_copy = copy.deepcopy(model)
    quantizer = quantizers[bitwidth]
    print(f'k-means quantizing model into {bitwidth} bits')
    quantizer.apply(model_copy, update_centroids=False)
    quantized_model_size = get_model_size(model_copy, bitwidth, True)
    print(f"    {bitwidth}-bit k-means quantized model has size={quantized_model_size/MiB:.2f} MiB")
    quantized_model_accuracy = eval_fn(model_copy, test_dataloader)
    print(f"    {bitwidth}-bit k-means quantized model has ssi loss={quantized_model_accuracy:.2f} before quantization-aware training ")
    accuracy_drop = abs(fp32_model_accuracy - quantized_model_accuracy)
    if accuracy_drop > accuracy_drop_threshold:
        print(f"        Quantization-aware training due to ssi loss drop={accuracy_drop:.2f} is larger than threshold={accuracy_drop_threshold:.2f}")
        config["General"]["model"] = model_copy
        trainer = Trainer(config)
        trainer.train(train_dataloader, val_dataloader)
        quantized_model_accuracy = eval_fn(model_copy, test_dataloader)
        print(f"    {bitwidth}-bit k-means quantized model has ssi loss={quantized_model_accuracy:.2f} before quantization-aware training ")
    else:
        print(f"        No need for quantization-aware training since ssi loss drop={accuracy_drop:.2f} is smaller than threshold={accuracy_drop_threshold:.2f}")
     

k-means quantizing model into 8 bits
    8-bit k-means quantized model has size=37.86 MiB
    8-bit k-means quantized model has ssi loss=0.03 before quantization-aware training 
        No need for quantization-aware training since ssi loss drop=0.00 is smaller than threshold=0.02
k-means quantizing model into 4 bits
    4-bit k-means quantized model has size=19.39 MiB
    4-bit k-means quantized model has ssi loss=0.08 before quantization-aware training 
        Quantization-aware training due to ssi loss drop=0.05 is larger than threshold=0.02
device: cuda
Epoch  1


Training: 100%|██████████| 1512/1512 [03:18<00:00,  7.63it/s, training_loss=0.0509]
Validation: 100%|██████████| 504/504 [00:38<00:00, 13.02it/s, validation_loss=0.0452]


Model saved at : models/DPTForDepthEstimation
Epoch  2


Training: 100%|██████████| 1512/1512 [03:19<00:00,  7.59it/s, training_loss=0.0425]
Validation: 100%|██████████| 504/504 [00:37<00:00, 13.54it/s, validation_loss=0.0411]


Model saved at : models/DPTForDepthEstimation
Epoch  3


Training: 100%|██████████| 1512/1512 [03:18<00:00,  7.60it/s, training_loss=0.0397]
Validation: 100%|██████████| 504/504 [00:37<00:00, 13.46it/s, validation_loss=0.0404]


Model saved at : models/DPTForDepthEstimation
Finished Training
    4-bit k-means quantized model has ssi loss=0.04 before quantization-aware training 
k-means quantizing model into 2 bits
    2-bit k-means quantized model has size=9.76 MiB
    2-bit k-means quantized model has ssi loss=0.11 before quantization-aware training 
        Quantization-aware training due to ssi loss drop=0.09 is larger than threshold=0.02
device: cuda
Epoch  1


Training: 100%|██████████| 1512/1512 [03:30<00:00,  7.19it/s, training_loss=0.201]
Validation: 100%|██████████| 504/504 [00:39<00:00, 12.86it/s, validation_loss=0.207]


Model saved at : models/DPTForDepthEstimation
Epoch  2


Training: 100%|██████████| 1512/1512 [03:34<00:00,  7.06it/s, training_loss=0.2]  
Validation: 100%|██████████| 504/504 [00:41<00:00, 12.14it/s, validation_loss=0.207]


Epoch  3


Training: 100%|██████████| 1512/1512 [03:40<00:00,  6.84it/s, training_loss=0.201]
Validation: 100%|██████████| 504/504 [00:42<00:00, 11.80it/s, validation_loss=0.207]


Finished Training
    2-bit k-means quantized model has ssi loss=0.20 before quantization-aware training 
