In [2]:
import random
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from torch.utils.data import DataLoader
from torchmetrics import MeanAbsoluteError, MeanSquaredError, R2Score
from tqdm.auto import tqdm

In [3]:
%load_ext autoreload
%autoreload 2

from src.model.CNN_regression import (
    DBHDepthDataset,
    DBHRegressor,
    evaluate,
    train_one_epoch,
)

In [4]:
# Browse the timm registry for ResNet-family model names (candidate backbones).
import timm

timm.list_models(filter="resnet*")

['resnet10t',
 'resnet14t',
 'resnet18',
 'resnet18d',
 'resnet26',
 'resnet26d',
 'resnet26t',
 'resnet32ts',
 'resnet33ts',
 'resnet34',
 'resnet34d',
 'resnet50',
 'resnet50_clip',
 'resnet50_clip_gap',
 'resnet50_gn',
 'resnet50_mlp',
 'resnet50c',
 'resnet50d',
 'resnet50s',
 'resnet50t',
 'resnet50x4_clip',
 'resnet50x4_clip_gap',
 'resnet50x16_clip',
 'resnet50x16_clip_gap',
 'resnet50x64_clip',
 'resnet50x64_clip_gap',
 'resnet51q',
 'resnet61q',
 'resnet101',
 'resnet101_clip',
 'resnet101_clip_gap',
 'resnet101c',
 'resnet101d',
 'resnet101s',
 'resnet152',
 'resnet152c',
 'resnet152d',
 'resnet152s',
 'resnet200',
 'resnet200d',
 'resnetaa34d',
 'resnetaa50',
 'resnetaa50d',
 'resnetaa101d',
 'resnetblur18',
 'resnetblur50',
 'resnetblur50d',
 'resnetblur101d',
 'resnetrs50',
 'resnetrs101',
 'resnetrs152',
 'resnetrs200',
 'resnetrs270',
 'resnetrs350',
 'resnetrs420',
 'resnetv2_18',
 'resnetv2_18d',
 'resnetv2_34',
 'resnetv2_34d',
 'resnetv2_50',
 'resnetv2_50d',
 'resne

In [5]:
# Same lookup for Vision Transformer variants (requires `timm` imported above).
timm.list_models(filter="vit*")

['vit_7b_patch16_dinov3',
 'vit_base_mci_224',
 'vit_base_patch8_224',
 'vit_base_patch14_dinov2',
 'vit_base_patch14_reg4_dinov2',
 'vit_base_patch16_18x2_224',
 'vit_base_patch16_224',
 'vit_base_patch16_224_miil',
 'vit_base_patch16_384',
 'vit_base_patch16_clip_224',
 'vit_base_patch16_clip_384',
 'vit_base_patch16_clip_quickgelu_224',
 'vit_base_patch16_dinov3',
 'vit_base_patch16_dinov3_qkvb',
 'vit_base_patch16_gap_224',
 'vit_base_patch16_plus_240',
 'vit_base_patch16_plus_clip_240',
 'vit_base_patch16_reg4_gap_256',
 'vit_base_patch16_rope_224',
 'vit_base_patch16_rope_ape_224',
 'vit_base_patch16_rope_mixed_224',
 'vit_base_patch16_rope_mixed_ape_224',
 'vit_base_patch16_rope_reg1_gap_256',
 'vit_base_patch16_rpn_224',
 'vit_base_patch16_siglip_224',
 'vit_base_patch16_siglip_256',
 'vit_base_patch16_siglip_384',
 'vit_base_patch16_siglip_512',
 'vit_base_patch16_siglip_gap_224',
 'vit_base_patch16_siglip_gap_256',
 'vit_base_patch16_siglip_gap_384',
 'vit_base_patch16_siglip

In [6]:
import gc

# Collect unreachable Python objects first so any CUDA tensors they still held
# are returned to PyTorch's caching allocator; only then can empty_cache()
# hand those blocks back to the driver. (The reverse order leaves them cached.)
n_collected = gc.collect()
torch.cuda.empty_cache()
n_collected

38

In [7]:
@dataclass
class Config:
    """Hyper-parameters and paths for the DBH depth-map regression runs.

    Every attribute is a real dataclass field, so ``vars(cfg)`` (used as the
    W&B run config) records all of them, including ``models`` and ``seed``.
    """

    epochs: int = 300
    batch_size: int = 32
    lr: float = 1e-3
    image_size: int = 192
    num_workers: int = 4
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    log_every_n_batches: int = 4

    base_path: str = "../dataset/DepthMapDBH2023/"
    segmentation_model_name: str = "DA3_LARGE"
    # Backbone names handed to DBHRegressor. A list default must go through
    # default_factory; an un-annotated list is a class attribute, not a field,
    # and would be missing from vars(cfg).
    models: list[str] = field(default_factory=lambda: ["resnet50"])
    # A fresh random seed per Config instance (same range as before), stored on
    # the config so it is logged with the run and the split can be reproduced.
    seed: int = field(default_factory=lambda: random.randint(0, 10_000))


cfg = Config()

# Seed every RNG in play; `seed` is also reused by the train/val split below.
seed = cfg.seed
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
print(f"seed={seed}")

# CSV file lists for the configured depth/segmentation model (previously the
# "DA3_LARGE" suffix was hardcoded and ignored cfg.segmentation_model_name).
depth_csv_name = f"files_with_depth_maps_{cfg.segmentation_model_name}.csv"
train_csv = Path(cfg.base_path) / "train" / "train" / depth_csv_name
test_csv = Path(cfg.base_path) / "test" / "test" / depth_csv_name

full_dataset = DBHDepthDataset(train_csv, cfg.base_path)
test_dataset = DBHDepthDataset(test_csv, cfg.base_path)

# 80/20 train/validation split of the training CSV; the test CSV stays held out.
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

In [8]:
# Reproducible split: the generator is seeded with the run seed.
train_ds, val_ds = torch.utils.data.random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(seed),
)

# Pinned host memory only speeds up host->GPU copies, so enable it for CUDA runs only.
pin_memory = cfg.device.startswith("cuda")


def _make_loader(dataset, shuffle: bool) -> DataLoader:
    """Build a DataLoader sharing the configured batch size, workers and pinning."""
    return DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=shuffle,
        num_workers=cfg.num_workers,
        pin_memory=pin_memory,
    )


# Only training data is shuffled; val/test keep a fixed order.
train_loader = _make_loader(train_ds, shuffle=True)
val_loader = _make_loader(val_ds, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)

In [None]:
# Train, validate (with early stopping) and test one regressor per backbone,
# each in its own W&B run.
for backbone in cfg.models:
    run_name = (
        f"{backbone}_depth_dbh_{cfg.segmentation_model_name}_{datetime.now():%Y%m%d_%H%M}"
    )
    ckpt_path = f"{run_name}_best.pt"

    # Using the run as a context manager finishes it on exit, so an exception
    # during training does not leave the W&B run dangling.
    with wandb.init(
        project="DBH-Depth-Map-CNN-Regression",
        name=run_name,
        config=vars(cfg),
    ) as run:
        model = DBHRegressor(backbone).to(cfg.device)

        optimizer = optim.AdamW(model.parameters(), lr=cfg.lr)
        loss_fn = nn.MSELoss()

        metrics = {
            "rmse": MeanSquaredError(squared=False).to(cfg.device),
            "mae": MeanAbsoluteError().to(cfg.device),
            "r2": R2Score().to(cfg.device),
        }

        # Early stopping: stop once the validation loss has not improved for
        # `patience` consecutive epochs (previously `wait` was never incremented,
        # so training always ran the full cfg.epochs).
        best_val = float("inf")
        best_epoch = None  # stays None if no checkpoint is ever written (e.g. NaN loss)
        patience, wait = 26, 0

        for epoch in tqdm(range(cfg.epochs)):
            train_loss = train_one_epoch(
                model, train_loader, optimizer, loss_fn, cfg.device, epoch, cfg.log_every_n_batches
            )

            val_metrics = evaluate(model, val_loader, loss_fn, metrics, cfg.device)

            run.log(
                {
                    "epoch": epoch,
                    "train/loss": train_loss,
                    "val/loss": val_metrics["loss"],
                    "val/rmse": val_metrics["rmse"],
                    "val/mae": val_metrics["mae"],
                    "val/r2": val_metrics["r2"],
                }
            )

            print(f"[{epoch:03d}] train={train_loss:.4f} val_rmse={val_metrics['rmse']:.3f}")

            if val_metrics["loss"] < best_val:
                best_val = val_metrics["loss"]
                best_epoch = epoch
                wait = 0

                torch.save(model.state_dict(), ckpt_path)

                artifact = wandb.Artifact(
                    name=f"{backbone}-dbh-regressor",
                    type="model",
                    metadata={
                        "backbone": backbone,
                        "segmentation_model": cfg.segmentation_model_name,
                        "epoch": epoch,
                        "val_loss": best_val,
                    },
                )

                artifact.add_file(ckpt_path, overwrite=True)
                run.log_artifact(artifact)
            else:
                wait += 1
                if wait >= patience:
                    print(
                        f"Early stopping at epoch {epoch}: no val improvement "
                        f"for {patience} epochs (best epoch {best_epoch})"
                    )
                    break

        # Test with the best checkpoint rather than the last epoch's weights.
        # NOTE(review): `metrics` is shared with validation; this assumes
        # evaluate() resets the torchmetrics state before accumulating.
        if best_epoch is None:
            print(f"{backbone}: validation loss never improved; no checkpoint to test.")
        else:
            state_dict = torch.load(ckpt_path, map_location=cfg.device, weights_only=True)
            model.load_state_dict(state_dict)
            test_metrics = evaluate(model, test_loader, loss_fn, metrics, cfg.device)
            run.log({f"test_{k}": v for k, v in test_metrics.items()})

    # Release this backbone's GPU memory before the next model is built.
    del model, optimizer, metrics
    torch.cuda.empty_cache()

print("All trainings done!")

[34m[1mwandb[0m: [wandb.login()] Loaded credentials for https://api.wandb.ai from WANDB_API_KEY.
[34m[1mwandb[0m: Currently logged in as: [33mmicrohum[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/300 [00:00<?, ?it/s]

Train:   0%|          | 0/26 [00:52<?, ?it/s]

Eval:   0%|          | 0/7 [00:53<?, ?it/s]

[000] train=1330.0266 val_rmse=40.066


Train:   0%|          | 0/26 [00:54<?, ?it/s]

Eval:   0%|          | 0/7 [00:56<?, ?it/s]

[001] train=416.4346 val_rmse=40.169


Train:   0%|          | 0/26 [00:53<?, ?it/s]

Eval:   0%|          | 0/7 [01:55<?, ?it/s]

[002] train=263.2625 val_rmse=33.952


Train:   0%|          | 0/26 [01:04<?, ?it/s]

Eval:   0%|          | 0/7 [01:03<?, ?it/s]

[003] train=195.4542 val_rmse=37.051


Train:   0%|          | 0/26 [01:03<?, ?it/s]

Eval:   0%|          | 0/7 [01:03<?, ?it/s]

[004] train=156.8458 val_rmse=31.377


Train:   0%|          | 0/26 [01:01<?, ?it/s]