<a href="https://colab.research.google.com/github/yujunyoung1107/GLPDepth_colab/blob/main/GLPDepth_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install tensorboardX
!pip install mmcv
!pip install timm
from google.colab import drive
drive.mount("/content/gdrive")

In [None]:
# Fetch the training code. Later cells expect the checkout at
# /content/GLPDepth_new, so this must run with /content as the working directory.
# NOTE(review): re-running this cell when the directory already exists makes
# `git clone` abort — presumably harmless, but confirm.
!git clone https://github.com/yujunyoung1107/GLPDepth_new.git

In [None]:
%cd /content/GLPDepth_new/datasets/
!unzip /content/gdrive/MyDrive/nyu_depth_v2.zip

In [None]:
# Copy the `mit_b4.pth` weights file from Drive into the repository's weights
# folder.
%cd /content/GLPDepth_new/code/models/weights/
# NOTE(review): the copy is stored as `mit_bt.pth` although the source file is
# `mit_b4.pth` — confirm this is the filename the model code actually loads.
!cp /content/gdrive/MyDrive/mit_b4.pth mit_bt.pth
# Switch to the code directory — presumably so the `models`, `utils`, `dataset`
# and `configs` packages imported in the next cell can be resolved.
%cd /content/GLPDepth_new/code/

In [None]:
import os
import cv2
import numpy as np
from datetime import datetime
import torch
import torch.optim as optim
import torch.backends.cudnn as cudnn
from tensorboardX import SummaryWriter
from tqdm.notebook import tqdm

from models.model import GLPDepth
import utils.metrics as metrics
from utils.criterion import SiLogLoss
import utils.logging as logging
from dataset.base_dataset import get_dataset
from configs.train_options import TrainOptions

In [None]:
%cd /content/GLPDepth_new/
metric_name = ['d1', 'd2', 'd3', 'abs_rel', 'sq_rel', 'rmse', 'rmse_log',
               'log10', 'silog']

model = GLPDepth(max_depth=80, is_train=True)
if torch.cuda.is_available():
    device = torch.device('cuda')
    cudnn.benchmark = True
    model = torch.nn.DataParallel(model)
else:
    device = torch.device('cpu')
print(device)

In [None]:
# Which dataset to train on and the folder it is read from.
data_type = 'nyudepthv2'
data_path = '/content/GLPDepth_new/datasets/'

#-----------------------hyper parameters---------------------------#
batch_size = 6      # samples per training batch
num_workers = 2     # worker processes of the training DataLoader
lr = 1e-4           # initial Adam learning rate (train() overwrites it every step)
epochs = 25         # number of training epochs
max_dep = 10        # upper depth bound handed to metrics.cropping_img
min_dep = 1e-3      # lower depth bound handed to metrics.cropping_img
#------------------------------------------------------------------#

# Crop size per supported dataset; any other dataset name gets no crop_size.
crop_sizes = {'nyudepthv2': (448, 576), 'kitti': (352, 704)}

dataset_kwargs = dict(dataset_name=data_type, data_path=data_path)
if data_type in crop_sizes:
    dataset_kwargs['crop_size'] = crop_sizes[data_type]

train_dataset = get_dataset(**dataset_kwargs)
val_dataset = get_dataset(is_train=False, **dataset_kwargs)

# Training batches are shuffled and the incomplete last batch is dropped;
# validation runs one sample at a time in dataset order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=True,
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    pin_memory=True,
)

In [8]:
def _scheduled_lr(step, steps_per_epoch, half_epoch, max_lr=1e-4, min_lr=3e-5):
    """Return the learning rate to use at optimisation step ``step``.

    During the first ``half_epoch`` epochs the rate rises from ``min_lr`` to
    ``max_lr``; afterwards it falls back towards ``min_lr``. Both phases follow
    a power curve with exponent 0.9, and the result never drops below
    ``min_lr``.
    """
    progress = step / steps_per_epoch / half_epoch
    if step < steps_per_epoch * half_epoch:
        return (max_lr - min_lr) * progress ** 0.9 + min_lr
    # Without the floor the decay branch keeps decreasing past `min_lr` (and
    # eventually below zero) once `progress` exceeds 2, e.g. for odd `epochs`.
    return max(min_lr, (min_lr - max_lr) * (progress - 1) ** 0.9 + max_lr)


def train(train_loader, model, criterion_d, optimizer, device, epoch):
    """Run one training epoch and return the loss of the last batch.

    Reads the module-level ``epochs`` and increments the module-level
    ``global_step`` once per batch; the learning rate of every parameter
    group is overwritten each step from that counter.

    Args:
        train_loader: iterable of batches with ``'image'`` and ``'depth'``
            entries; must support ``len()``.
        model: network called as ``model(images)``, returning a mapping with
            a ``'pred_d'`` entry.
        criterion_d: loss called as ``criterion_d(prediction, depth_gt)``.
        optimizer: optimizer whose ``param_groups`` are updated in place.
        device: device the batch tensors are moved to.
        epoch: current epoch number, used only for the progress-bar label.

    Returns:
        The value produced by ``criterion_d`` for the final batch, or ``None``
        if the loader yielded no batches.
    """
    global global_step
    model.train()
    depth_loss = logging.AverageMeter()
    # At least one warm-up epoch, so the schedule never divides by zero.
    half_epoch = max(epochs // 2, 1)
    # Derive the schedule length from the actual loader instead of a fixed
    # step count, so it stays correct for any batch size or dataset size.
    steps_per_epoch = len(train_loader)
    loss_d = None

    with tqdm(train_loader, unit='batch') as pbar:
        pbar.set_description(f'Epoch {epoch}')
        for batch in pbar:
            global_step += 1

            # The rate is the same for every parameter group: compute it once.
            current_lr = _scheduled_lr(global_step, steps_per_epoch, half_epoch)
            for param_group in optimizer.param_groups:
                param_group['lr'] = current_lr

            input_RGB = batch['image'].to(device)
            depth_gt = batch['depth'].to(device)

            preds = model(input_RGB)

            optimizer.zero_grad()
            loss_d = criterion_d(preds['pred_d'].squeeze(), depth_gt)
            depth_loss.update(loss_d.item(), input_RGB.size(0))
            loss_d.backward()

            optimizer.step()

            pbar.set_postfix({'Loss' : depth_loss.val, 'Loss avg' : depth_loss.avg})

    return loss_d


In [9]:
def validate(val_loader, model, criterion_d, device, epoch, max_dep, min_dep):
    """Checkpoint the model, then evaluate it on ``val_loader``.

    Reads the module-level ``metric_name``, ``data_type`` and ``epochs``.

    Args:
        val_loader: iterable of batches with ``'image'`` and ``'depth'``
            entries; must support ``len()``.
        model: network called as ``model(images)``, returning a mapping with
            a ``'pred_d'`` entry.
        criterion_d: loss called as ``criterion_d(prediction, depth_gt)``.
        device: device the batch tensors are moved to.
        epoch: current epoch number; used in the checkpoint filename and the
            progress display.
        max_dep: upper depth bound forwarded to ``metrics.cropping_img``.
        min_dep: lower depth bound forwarded to ``metrics.cropping_img``.

    Returns:
        ``(result_metrics, loss_d)``: a dict with the per-batch mean of every
        metric in ``metric_name``, and the running average of the loss as
        tracked by ``logging.AverageMeter``.
    """
    depth_loss = logging.AverageMeter()
    model.eval()

    # Persist the weights of this epoch to Drive before evaluating.
    torch.save(model.state_dict(), '/content/gdrive//MyDrive/epoch_%02d_model.ckpt' % epoch)

    result_metrics = {metric: 0.0 for metric in metric_name}
    num_batches = 0

    # Nothing in evaluation needs gradients, so the whole loop runs without
    # autograd bookkeeping.
    with torch.no_grad():
        for batch_idx, batch in enumerate(val_loader):
            input_RGB = batch['image'].to(device)
            depth_gt = batch['depth'].to(device).squeeze()

            pred_d = model(input_RGB)['pred_d'].squeeze()

            batch_loss = criterion_d(pred_d, depth_gt)
            depth_loss.update(batch_loss.item(), input_RGB.size(0))

            pred_crop, gt_crop = metrics.cropping_img(pred_d, max_dep, min_dep, depth_gt, data_type)
            computed_result = metrics.eval_depth(pred_crop, gt_crop)

            logging.progress_bar(batch_idx, len(val_loader), epochs, epoch)

            for key in result_metrics:
                result_metrics[key] += computed_result[key]
            num_batches += 1

    # Turn the accumulated sums into per-batch means; with no batches the
    # zero-initialised values are returned unchanged.
    if num_batches:
        for key in result_metrics:
            result_metrics[key] /= num_batches

    loss_d = depth_loss.avg
    return result_metrics, loss_d

In [None]:
# Training settings
criterion_d = SiLogLoss()
optimizer = optim.Adam(model.parameters(), lr)

global global_step
global_step = 0

for epoch in range(1, epochs + 1):
    print('\nEpoch: %03d - %03d' % (epoch, epochs))
    loss_train = train(train_loader, model, criterion_d, optimizer=optimizer, 
                        device=device, epoch=epoch)

    
    results_dict, loss_val = validate(val_loader, model, criterion_d, device, epoch, max_dep, min_dep)
    result_lines = logging.display_result(results_dict)
    print(result_lines)
