# Imports and preparing

In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import torch
torch.cuda.empty_cache()

In [3]:
import random
random.seed(42)

In [4]:
from datetime import datetime
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

In [5]:
import logging
logging.basicConfig(filename=f"{timestamp}.log", encoding='utf-8', level=logging.DEBUG)

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

fileHandler = logging.FileHandler(f"{timestamp}.log")
fileHandler.setLevel(logging.DEBUG)

consoleHandler = logging.StreamHandler()
consoleHandler.setLevel(logging.DEBUG)

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

fileHandler.setFormatter(formatter)
consoleHandler.setFormatter(formatter)

logger.addHandler(fileHandler)
logger.addHandler(consoleHandler)

In [6]:
import os
import multiprocessing
import torch.nn as nn
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR
from flash.core.optimizers import LinearWarmupCosineAnnealingLR
from PIL import Image
from transformers import AutoImageProcessor, DPTForDepthEstimation, get_scheduler, Dinov2Config, DPTConfig
from tqdm.auto import tqdm

# Constants

In [7]:
IMAGE_SIZE = (256, 255)

In [8]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"using device {DEVICE}")

2024-01-15 09:47:40,682 - __main__ - INFO - using device cuda:0


In [9]:
HEAD_TYPE = "scratch"

In [10]:
BACKBONE_TYPE = 'small' # in ("small", "base", "large" or "giant")

In [11]:
RESUME_EPOCH = 0

In [12]:
RESUME_PATH = "HF-model_small_scratch_20240113_103956_102"
if RESUME_PATH is not None:
    BACKBONE_TYPE = RESUME_PATH.split('/')[-1].split('_')[1]
    HEAD_TYPE = RESUME_PATH.split('/')[-1].split('_')[2]
    RESUME_EPOCH = int(RESUME_PATH.split('/')[-1].split('_')[-1])
    STATE_PATH = RESUME_PATH.split('-')[-1]

In [13]:
CONFIG_NAME = f"facebook/dinov2-{BACKBONE_TYPE}"

In [14]:
TRAIN_PATH = "/home/jovyan/work/saved_data/data/thumbnails/test"
TEST_PATH = "/home/jovyan/work/saved_data/data/thumbnails/train"

In [15]:
EPOCHS = 500 - RESUME_EPOCH
BATCH_SIZE = 32
EVAL_INTERVAL = 1

In [16]:
START_LR = 1e-8
WARMUP_LR = 1e-9
END_LR = 1e-4

In [17]:
TRAIN_BACKBONE = False
TRAIN_NECK = True
TRAIN_HEAD = True

In [18]:
HIDDEN_DROP = 0
ATTENTION_DROP = 0

# Classes

In [19]:
class CustomNPZDataset(Dataset):
    def __init__(self, path, image_processor, transform=None, image_transforms=None):
        self.path = path
        self.files = list(Path(path).glob('*.npz'))
        self.image_processor = image_processor
        self.transform = transform
        self.image_transforms = image_transforms

    def __len__(self):
        return len(self.files)

    def __getitem__(self, item):
        with np.load(str(self.files[item])) as data:
            X_numpy = data['X']
            y_numpy = data['y']
        X_torch = torch.from_numpy(X_numpy)
        y_torch = torch.from_numpy(y_numpy).unsqueeze(0)
        if self.transform is not None:
            X_torch = self.transform(X_torch)
            y_torch = self.transform(y_torch)
        if self.image_transforms is not None:
            X_torch = self.transform(X_torch)
        return X_torch, y_torch

In [20]:
class PaperSigLoss(nn.Module):
    def __init__(self, valid_mask=True, a=10, l=0.85):
        super(PaperSigLoss, self).__init__()
        
        self.valid_mask = valid_mask

        self.alpha = a
        self.lamb = l

        self.eps = 0.001  # avoid grad explode
        
    def paperSigloss(self, input, target):
        if self.valid_mask:
            valid_mask = target > 0
            input = input[valid_mask]
            target = target[valid_mask]

        delta = torch.log(target + self.eps) - torch.log(input + self.eps)
        loss = torch.mean(delta ** 2) - self.lamb / ((torch.numel(delta)) ** 2) * (torch.sum(delta) ** 2) 
        return loss

    def forward(self, depth_pred, depth_gt):
        loss_depth = self.paperSigloss(depth_pred, depth_gt)
        return self.alpha * loss_depth

In [21]:
class SigLoss(nn.Module):
    def __init__(
        self, valid_mask=True, max_depth=None):
        super(SigLoss, self).__init__()
        
        self.valid_mask = valid_mask
        self.max_depth = max_depth

        self.eps = 0.001  # avoid grad explode

    def sigloss(self, input, target):
        if self.valid_mask:
            valid_mask = target > 0
            if self.max_depth is not None:
                valid_mask = torch.logical_and(target > 0, target <= self.max_depth)
            input = input[valid_mask]
            target = target[valid_mask]

        g = torch.log(input + self.eps) - torch.log(target + self.eps)
        Dg = torch.var(g) + 0.15 * torch.pow(torch.mean(g), 2)
        return torch.sqrt(Dg)

    def forward(self, depth_pred, depth_gt):
        loss_depth = self.sigloss(depth_pred, depth_gt)
        return loss_depth

In [22]:
class MaskedMAE(nn.Module):
    def __init__(self, valid_mask=True, max_depth=None):
        super(MaskedMAE, self).__init__()
        
        self.valid_mask = valid_mask
        self.max_depth = max_depth

    def mae(self, input, target):
        if self.valid_mask:
            valid_mask = target > 0
            if self.max_depth is not None:
                valid_mask = torch.logical_and(target > 0, target <= self.max_depth)
            input = input[valid_mask]
            target = target[valid_mask]

        mae = torch.abs(input - target).mean()
        return mae
    
    def forward(self, depth_pred, depth_gt):
        metric_mae = self.mae(depth_pred, depth_gt)
        return metric_mae

In [23]:
class MaskedR2Score(nn.Module):
    def __init__(self, valid_mask=True, max_depth=None):
        super(MaskedR2Score, self).__init__()

        self.valid_mask = valid_mask
        self.max_depth = max_depth

    def r2(self, input, target):
        if self.valid_mask:
            valid_mask = target > 0
            if self.max_depth is not None:
                valid_mask = torch.logical_and(target > 0, target <= self.max_depth)
            input = input[valid_mask]
            target = target[valid_mask]

        mean_target = torch.mean(target)
        ss_total = torch.sum((target - mean_target)**2)
        ss_residual = torch.sum((input - target)**2)

        r2 = 1 - (ss_residual / ss_total)
        return r2
    
    def forward(self, depth_pred, depth_gt):
        metric_r2 = self.r2(depth_pred, depth_gt)
        return metric_r2

# Initiallization

In [24]:
image_processor = AutoImageProcessor.from_pretrained(CONFIG_NAME)

backbone_config = Dinov2Config.from_pretrained(
    CONFIG_NAME, 
    out_features=["stage1", "stage2", "stage3", "stage4"], 
    reshape_hidden_states=False,
    hidden_dropout_prob=HIDDEN_DROP,
    attention_probs_dropout_prob=ATTENTION_DROP
)
config = DPTConfig(
    backbone_config=backbone_config,
    hidden_dropout_prob=HIDDEN_DROP,
    attention_probs_dropout_prob=ATTENTION_DROP
)

if RESUME_PATH is None:
    model = DPTForDepthEstimation(config=config)
else:
    model = DPTForDepthEstimation.from_pretrained(RESUME_PATH)

In [25]:
model = model.cuda()

In [26]:
augmentation_transform = transforms.Compose([
    # transforms.RandomVerticalFlip(),
    # transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(90)#,
    # transforms.RandomResizedCrop((IMAGE_SIZE[0], IMAGE_SIZE[1]), scale=(0.8, 1.0))
])

In [27]:
img_only_transforms = transforms.Compose([
    transforms.ColorJitter(brightness=0.1, contrast=0.1)
])

In [28]:
train_dataset = CustomNPZDataset(path=TRAIN_PATH, image_processor=image_processor)#, transform=augmentation_transform, image_transforms=img_only_transforms)
validation_dataset = CustomNPZDataset(path=TEST_PATH, image_processor=image_processor)

In [29]:
training_loader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE, num_workers=8, pin_memory=True)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, num_workers=8, pin_memory=True)

In [30]:
optimizer = AdamW(model.parameters(), lr=START_LR)

In [31]:
total_steps = EPOCHS * len(training_loader)
logger.info(f"{total_steps} Total Steps")
lr_scheduler = LinearWarmupCosineAnnealingLR(optimizer, 12000, total_steps, warmup_start_lr=WARMUP_LR, eta_min=END_LR)

2024-01-15 09:47:42,153 - __main__ - INFO - 55720 Total Steps


In [32]:
last_step = RESUME_EPOCH * len(training_loader)
logger.info(f"Continueing at step {last_step}")
for _ in range(last_step):
    lr_scheduler.step()

2024-01-15 09:47:42,160 - __main__ - INFO - Continueing at step 14280


In [33]:
loss_fn = PaperSigLoss()

In [34]:
mae_fn = MaskedMAE()
r2_fn = MaskedR2Score()

# Helper Functions

In [35]:
def run_model(data):
    inputs, labels = data

    images = [Image.fromarray(input.numpy().transpose(1, 2, 0)) for input in inputs]

    inputs = image_processor(images=images, return_tensors="pt")

    inputs = inputs.to(DEVICE)
    
    outputs = model(**inputs)

    predicted_depth = outputs['predicted_depth']

    predictions = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=IMAGE_SIZE,
        mode="bicubic",
        align_corners=False,
    )

    labels = labels.to(DEVICE)
    
    loss = loss_fn(predictions, labels)

    mae = mae_fn(predictions, labels)
    r2 = r2_fn(predictions, labels)
    
    return loss, mae, r2

In [36]:
def train_one_epoch():
    running_loss = 0.
    epoch_loss = 0.

    running_mae = 0.
    running_r2 = 0.
    epoch_mae = 0.
    epoch_r2 = 0.

    for i, data in tqdm(enumerate(training_loader), total=len(training_loader)):
        loss, mae, r2 = run_model(data)

        if loss.isnan().all():
            logger.error("Exploding Gradeints!")
            raise Exception('Exploding Gradients!')

        optimizer.zero_grad()
        
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1e-3)

        optimizer.step()

        lr_scheduler.step()

        running_loss += loss.item()

        running_mae += mae.item()
        running_r2 += r2.item()

    epoch_loss = running_loss / len(training_loader)

    epoch_mae = running_mae / len(training_loader)
    epoch_r2 = running_r2 / len(training_loader)
    
    return epoch_loss, epoch_mae, epoch_r2

In [37]:
# lr_scheduler.get_lr()

# Training Loop

In [38]:
writer = SummaryWriter(f"runs/dinov2_{BACKBONE_TYPE}_dpt_{HEAD_TYPE}_{timestamp}")

epoch_number = 1 + RESUME_EPOCH

best_vloss = 1_000_000.

for epoch in range(RESUME_EPOCH, EPOCHS+RESUME_EPOCH):
    logger.info('EPOCH {}:'.format(epoch_number))

    model.train(True)
    
    for param in model.backbone.parameters():
        param.requires_grad = TRAIN_BACKBONE
    for param in model.neck.parameters():
        param.requires_grad = TRAIN_NECK
    for param in model.head.parameters():
        param.requires_grad = TRAIN_HEAD
        
    avg_loss, avg_mae, avg_r2 = train_one_epoch()

    logger.info('LOSS Train {}'.format(avg_loss))

    logger.info('MAE Train {}'.format(avg_mae))
    logger.info('R2 Train {}'.format(avg_r2))

    writer.add_scalars('Training Loss',
                        { 'Training' : avg_loss },
                        epoch_number)
    writer.add_scalars('Training MAE',
                    { 'Training' : avg_mae },
                    epoch_number)
    writer.add_scalars('Training R2',
                    { 'Training' : avg_r2 },
                    epoch_number)

    if epoch_number % EVAL_INTERVAL == 0:
        running_vloss = 0.0

        running_vmae = 0.0
        running_vr2 = 0.0
    
        model.eval()
    
        with torch.no_grad():
            for i, vdata in tqdm(enumerate(validation_loader), total=len(validation_loader)):
                vloss, vmae, vr2 = run_model(vdata)
                running_vloss += vloss
                running_vmae += vmae
                running_vr2 += vr2
    
        avg_vloss = running_vloss / len(validation_loader)
        logger.info('LOSS valid {}'.format(avg_vloss))
    
        avg_vmae = running_vmae / len(validation_loader)
        logger.info('MAE valid {}'.format(avg_vmae))
        avg_vr2 = running_vr2 / len(validation_loader)
        logger.info('R2 valid {}'.format(avg_vr2))
    
        writer.add_scalars('Validation Loss',
                        { 'Validation' : avg_vloss },
                        epoch_number)
        writer.add_scalars('Validation MAE',
                        { 'Validation' : avg_vmae },
                        epoch_number)
        writer.add_scalars('Validation R2',
                        { 'Validation' : avg_vr2 },
                        epoch_number)

        model_path = 'model_{}_{}_{}_{}'.format(BACKBONE_TYPE, HEAD_TYPE, timestamp, epoch_number)
        if avg_vloss < best_vloss:
            best_vloss = avg_vloss
            model_path = 'model_{}_{}_{}_{}_best'.format(BACKBONE_TYPE, HEAD_TYPE, timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)
        if not os.path.exists(f"HF-{model_path}"):
            os.mkdir(f"HF-{model_path}")
        model.save_pretrained(f"HF-{model_path}")
    
    writer.flush()
    
    epoch_number += 1

2024-01-15 09:47:42,245 - __main__ - INFO - EPOCH 103:


  0%|          | 0/140 [00:00<?, ?it/s]

2024-01-15 09:49:10,839 - __main__ - INFO - LOSS Train 4.996352878638676
2024-01-15 09:49:10,840 - __main__ - INFO - MAE Train 2.5307401376111165
2024-01-15 09:49:10,841 - __main__ - INFO - R2 Train 0.8337895806346621


  0%|          | 0/120 [00:00<?, ?it/s]

2024-01-15 09:50:08,890 - __main__ - INFO - LOSS valid 18.485225677490234
2024-01-15 09:50:08,893 - __main__ - INFO - MAE valid 4.455069541931152
2024-01-15 09:50:08,895 - __main__ - INFO - R2 valid 0.4657798409461975
2024-01-15 09:50:09,221 - __main__ - INFO - EPOCH 104:


  0%|          | 0/140 [00:00<?, ?it/s]

2024-01-15 09:51:37,220 - __main__ - INFO - LOSS Train 4.891311999729702
2024-01-15 09:51:37,221 - __main__ - INFO - MAE Train 2.4929042220115663
2024-01-15 09:51:37,222 - __main__ - INFO - R2 Train 0.836650635940688


  0%|          | 0/120 [00:00<?, ?it/s]

2024-01-15 09:52:35,478 - __main__ - INFO - LOSS valid 18.459365844726562
2024-01-15 09:52:35,480 - __main__ - INFO - MAE valid 4.470583438873291
2024-01-15 09:52:35,483 - __main__ - INFO - R2 valid 0.4617373049259186
2024-01-15 09:52:35,806 - __main__ - INFO - EPOCH 105:


  0%|          | 0/140 [00:00<?, ?it/s]

2024-01-15 09:53:38,654 - __main__ - ERROR - Exploding Gradeints!


Exception: Exploding Gradients!