In [1]:
import logging
import os
import sys
import traceback

import torch

from hydra import compose, initialize_config_module
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate

from omegaconf import OmegaConf

from training.utils.train_utils import makedir, register_omegaconf_resolvers

# Set HYDRA_FULL_ERROR for this process before any Hydra call is made.
# NOTE(review): Hydra documents this variable as enabling complete stack
# traces instead of its shortened error output — confirm against Hydra docs.
os.environ["HYDRA_FULL_ERROR"] = "1"

In [2]:
def single_proc_run(local_rank, main_port, cfg, world_size):
    """Run the trainer for a single process inside the current interpreter.

    Exports MASTER_ADDR, MASTER_PORT, RANK, LOCAL_RANK and WORLD_SIZE into
    ``os.environ``, makes a best-effort attempt to register the OmegaConf
    resolvers, then instantiates ``cfg.trainer`` and calls its ``run()``.

    Args:
        local_rank: Value written (as a string) to both RANK and LOCAL_RANK.
        main_port: Value written (as a string) to MASTER_PORT.
        cfg: Config object exposing a ``trainer`` node to instantiate.
        world_size: Value written (as a string) to WORLD_SIZE.
    """
    env_settings = (
        ("MASTER_ADDR", "localhost"),
        ("MASTER_PORT", main_port),
        ("RANK", local_rank),
        ("LOCAL_RANK", local_rank),
        ("WORLD_SIZE", world_size),
    )
    for env_name, env_value in env_settings:
        os.environ[env_name] = str(env_value)

    # Registration is best-effort: any exception is logged at INFO level and
    # otherwise ignored.
    # NOTE(review): presumably this guards against the resolvers already being
    # registered in this process — confirm against train_utils.
    try:
        register_omegaconf_resolvers()
    except Exception as err:
        logging.info(err)

    # _recursive_=False is forwarded to hydra.utils.instantiate unchanged.
    instantiate(cfg.trainer, _recursive_=False).run()


def single_node_runner(cfg, main_port: int):
    """Run training on one node with a single process in this interpreter.

    Ensures the multiprocessing start method is ``"spawn"`` and then calls
    ``single_proc_run`` with ``local_rank=0`` and ``world_size=1``.

    Args:
        cfg: Config object forwarded unchanged to ``single_proc_run``.
        main_port: Port forwarded to ``single_proc_run`` as ``main_port``.
    """
    # CUDA runtime does not support `fork`
    # set_start_method() raises RuntimeError once the start method has been
    # fixed, which happens whenever this function is called a second time in
    # the same kernel. Skip the call when "spawn" is already active and force
    # it otherwise so that re-running the training cell works.
    if torch.multiprocessing.get_start_method(allow_none=True) != "spawn":
        torch.multiprocessing.set_start_method("spawn", force=True)

    single_proc_run(local_rank=0, main_port=main_port, cfg=cfg, world_size=1)


def format_exception(e: Exception, limit=20):
    """Render an exception and its traceback as one string.

    Args:
        e: The exception to format; its ``__traceback__`` supplies the frames.
        limit: Passed to ``traceback.format_tb`` as the ``limit`` argument.

    Returns:
        ``"<ExceptionType>: <message>"`` followed by a ``Traceback:`` line and
        the formatted traceback entries.
    """
    frame_lines = traceback.format_tb(e.__traceback__, limit=limit)
    summary = f"{type(e).__name__}: {e}"
    return summary + "\nTraceback:\n" + "".join(frame_lines)


def add_pythonpath_to_sys_path():
    """Prepend the entries of the PYTHONPATH environment variable to ``sys.path``.

    Does nothing when PYTHONPATH is unset or empty. Entries are split on
    ``os.pathsep`` (":" on POSIX, ";" on Windows) so that Windows paths with
    drive letters are not torn apart. If ``sys.path`` already starts with
    exactly these entries the call is a no-op, so re-running the notebook cell
    does not pile up duplicates.
    """
    raw_pythonpath = os.environ.get("PYTHONPATH")
    if not raw_pythonpath:
        return

    entries = raw_pythonpath.split(os.pathsep)
    # Already at the front (e.g. this function ran before): nothing to do.
    if sys.path[: len(entries)] == entries:
        return
    sys.path = entries + sys.path



In [3]:
# Hydra keeps one process-wide GlobalHydra instance and refuses to be
# initialized twice; clear any earlier initialization so this cell can be
# re-run in the same kernel.
if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()
initialize_config_module("sam2", version_base="1.2")

# OmegaConf raises ValueError when a resolver name is registered a second
# time, so a re-run of this cell is reported and skipped rather than failing.
# NOTE(review): assumes register_omegaconf_resolvers registers without
# replace=True — confirm against training.utils.train_utils.
try:
    register_omegaconf_resolvers()
except ValueError as err:
    logging.warning("OmegaConf resolvers not re-registered: %s", err)

In [4]:
# Compose the fine-tuning config by name and keep it in `cfg`; the cells below
# edit this object before it is handed to the trainer.
cfg = compose(config_name="configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml")

In [5]:
# Replace the `_target_` of the first video dataset (inside the first training
# dataset) so that JSONRawDataset is the class named for instantiation.
# NOTE(review): presumably needed because the annotations are JSON — confirm.
cfg.trainer.data.train.datasets[0].dataset.datasets[0].video_dataset._target_ = 'training.dataset.vos_raw_dataset.JSONRawDataset'

In [6]:
# Customize the config
cfg.scratch.max_num_objects = 3
cfg.scratch.num_epochs = 20
cfg.launcher.gpus_per_node = 1
cfg.launcher.num_nodes = 1
cfg.dataset.img_folder = "/home/kasm-user/sam2_ft_runpod/mini_dataset_sav1/images"
cfg.dataset.gt_folder = "/home/kasm-user/sam2_ft_runpod/mini_dataset_sav1/annotations"
cfg.dataset.file_list_txt = "/home/kasm-user/sam2_ft_runpod/mini_dataset_sav1/list_files.txt"
cfg.trainer.checkpoint.model_weight_initializer.state_dict.checkpoint_path = (
    "/home/kasm-user/sam2_ft_runpod/checkpoints/sam2.1_hiera_base_plus.pt"
)

In [7]:
# When the config leaves the experiment log directory unset, fall back to
# <current working directory>/sam2_logs/experiment_log_dir.
launcher_cfg = cfg.launcher
if launcher_cfg.experiment_log_dir is None:
    fallback_log_dir = os.path.join(os.getcwd(), "sam2_logs", "experiment_log_dir")
    launcher_cfg.experiment_log_dir = fallback_log_dir

In [8]:
# Dump the composed config as YAML between two banner lines.
banner_top = "###################### Train App Config ####################"
banner_bottom = "############################################################"
print(banner_top, OmegaConf.to_yaml(cfg), banner_bottom, sep="\n")

###################### Train App Config ####################
scratch:
  resolution: 1024
  train_batch_size: 1
  num_train_workers: 10
  num_frames: 8
  max_num_objects: 3
  base_lr: 5.0e-06
  vision_lr: 3.0e-06
  phases_per_epoch: 1
  num_epochs: 20
dataset:
  img_folder: /home/kasm-user/sam2_ft_runpod/mini_dataset_sav1/images
  gt_folder: /home/kasm-user/sam2_ft_runpod/mini_dataset_sav1/annotations
  file_list_txt: /home/kasm-user/sam2_ft_runpod/mini_dataset_sav1/list_files.txt
  multiplier: 2
vos:
  train_transforms:
  - _target_: training.dataset.transforms.ComposeAPI
    transforms:
    - _target_: training.dataset.transforms.RandomHorizontalFlip
      consistent_transform: true
    - _target_: training.dataset.transforms.RandomAffine
      degrees: 25
      shear: 20
      image_interpolation: bilinear
      consistent_transform: true
    - _target_: training.dataset.transforms.RandomResizeAPI
      sizes: ${scratch.resolution}
      square: true
      consistent_transform: true


In [9]:
# Put the PYTHONPATH entries at the front of sys.path.
add_pythonpath_to_sys_path()
# Call the project's makedir helper on the experiment log directory. As the
# last expression of the cell, its return value is what the cell displays.
makedir(cfg.launcher.experiment_log_dir)

True

In [10]:
# Run training in this kernel process. 4500 is passed as main_port, which
# single_proc_run exports as the MASTER_PORT environment variable.
single_node_runner(cfg, 4500)

INFO 2025-02-12 16:03:46,184 train_utils.py: 108: MACHINE SEED: 2460
INFO 2025-02-12 16:03:46,190 train_utils.py: 154: Logging ENV_VARIABLES
INFO 2025-02-12 16:03:46,191 train_utils.py: 155: AUDIO_PORT=4901
CLICOLOR=1
CLICOLOR_FORCE=1
COLORTERM=truecolor
CONDA_DEFAULT_ENV=sam2_ft
CONDA_EXE=/home/kasm-user/miniconda3/bin/conda
CONDA_PREFIX=/home/kasm-user/miniconda3/envs/sam2_ft
CONDA_PREFIX_1=/home/kasm-user/miniconda3
CONDA_PROMPT_MODIFIER=(sam2_ft) 
CONDA_PYTHON_EXE=/home/kasm-user/miniconda3/bin/python
CONDA_SHLVL=2
CUDA_MODULE_LOADING=LAZY
DBUS_SESSION_BUS_ADDRESS=unix:abstract=/tmp/dbus-TtTD6QIbJZ,guid=045a9a40a2789f9fbf65f0bb67acc207
DEBIAN_FRONTEND=noninteractive
DESKTOP_SESSION=xfce
DISPLAY=:1.0
DISTRO=ubuntu
FORCE_COLOR=1
GIT_PAGER=cat
GOMP_SPINCOUNT=0
HOME=/home/kasm-user
HOSTNAME=33986542191c
HYDRA_FULL_ERROR=1
INST_SCRIPTS=/dockerstartup/install
JPY_PARENT_PID=5979
JPY_SESSION_NAME=/home/kasm-user/sam2_ft_runpod/training/train.ipynb
JUPYTER_PASSWORD=7jkicez5h5lgfzt8nkm7
KAS

grad.sizes() = [64, 256, 1, 1], strides() = [256, 1, 256, 256]
bucket_view.sizes() = [64, 256, 1, 1], strides() = [256, 1, 1, 1] (Triggered internally at /pytorch/torch/csrc/distributed/c10d/reducer.cpp:327.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


INFO 2025-02-12 16:04:04,723 trainer.py: 950: Estimated time remaining: 00d 00h 04m
INFO 2025-02-12 16:04:04,726 trainer.py: 892: Synchronizing meters
INFO 2025-02-12 16:04:04,727 trainer.py: 830: Losses and meters: {'Losses/train_all_loss': 1.0697195511311293, 'Losses/train_all_loss_mask': 0.007477112259948626, 'Losses/train_all_loss_dice': 0.6235996270552278, 'Losses/train_all_loss_iou': 0.29358856566250324, 'Losses/train_all_loss_class': 0.002989097949466668, 'Losses/train_all_core_loss': 1.0697195511311293, 'Trainer/where': 0.04375, 'Trainer/epoch': 0, 'Trainer/steps_train': 8}
INFO 2025-02-12 16:04:11,348 train_utils.py: 271: Train Epoch: [1][0/8] | Batch Time: 5.57 (5.57) | Data Time: 4.69 (4.69) | Mem (GB): 38.00 (38.00/38.00) | Time Elapsed: 00d 00h 00m | Losses/train_all_loss: 1.49e-01 (1.49e-01)
INFO 2025-02-12 16:04:18,595 trainer.py: 950: Estimated time remaining: 00d 00h 03m
INFO 2025-02-12 16:04:18,597 trainer.py: 892: Synchronizing meters
INFO 2025-02-12 16:04:18,598 tra



INFO 2025-02-12 16:04:55,455 train_utils.py: 271: Train Epoch: [4][0/8] | Batch Time: 5.77 (5.77) | Data Time: 4.63 (4.63) | Mem (GB): 42.00 (42.00/42.00) | Time Elapsed: 00d 00h 01m | Losses/train_all_loss: 9.10e-01 (9.10e-01)
INFO 2025-02-12 16:05:03,538 trainer.py: 950: Estimated time remaining: 00d 00h 03m
INFO 2025-02-12 16:05:03,540 trainer.py: 892: Synchronizing meters
INFO 2025-02-12 16:05:03,541 trainer.py: 830: Losses and meters: {'Losses/train_all_loss': 1.44469029083848, 'Losses/train_all_loss_mask': 0.010664786474080756, 'Losses/train_all_loss_dice': 0.7243481408804655, 'Losses/train_all_loss_iou': 0.45657253614626825, 'Losses/train_all_loss_class': 0.05047395192377735, 'Losses/train_all_core_loss': 1.44469029083848, 'Trainer/where': 0.24375, 'Trainer/epoch': 4, 'Trainer/steps_train': 40}
INFO 2025-02-12 16:05:10,682 train_utils.py: 271: Train Epoch: [5][0/8] | Batch Time: 5.95 (5.95) | Data Time: 5.24 (5.24) | Mem (GB): 28.00 (28.00/28.00) | Time Elapsed: 00d 00h 01m | Lo



INFO 2025-02-12 16:05:56,002 train_utils.py: 271: Train Epoch: [8][0/8] | Batch Time: 5.62 (5.62) | Data Time: 4.71 (4.71) | Mem (GB): 39.00 (39.00/39.00) | Time Elapsed: 00d 00h 02m | Losses/train_all_loss: 4.49e-01 (4.49e-01)
INFO 2025-02-12 16:06:03,662 trainer.py: 950: Estimated time remaining: 00d 00h 02m
INFO 2025-02-12 16:06:03,665 trainer.py: 892: Synchronizing meters
INFO 2025-02-12 16:06:03,666 trainer.py: 830: Losses and meters: {'Losses/train_all_loss': 0.5974392648786306, 'Losses/train_all_loss_mask': 0.004840199399041012, 'Losses/train_all_loss_dice': 0.3695075288414955, 'Losses/train_all_loss_iou': 0.12935322988778353, 'Losses/train_all_loss_class': 0.0017744944634614512, 'Losses/train_all_core_loss': 0.5974392648786306, 'Trainer/where': 0.44375, 'Trainer/epoch': 8, 'Trainer/steps_train': 72}
INFO 2025-02-12 16:06:12,204 train_utils.py: 271: Train Epoch: [9][0/8] | Batch Time: 7.33 (7.33) | Data Time: 6.22 (6.22) | Mem (GB): 43.00 (43.00/43.00) | Time Elapsed: 00d 00h 02



INFO 2025-02-12 16:07:25,337 train_utils.py: 271: Train Epoch: [14][0/8] | Batch Time: 5.77 (5.77) | Data Time: 4.65 (4.65) | Mem (GB): 43.00 (43.00/43.00) | Time Elapsed: 00d 00h 03m | Losses/train_all_loss: 1.00e+00 (1.00e+00)
INFO 2025-02-12 16:07:32,668 trainer.py: 950: Estimated time remaining: 00d 00h 01m
INFO 2025-02-12 16:07:32,670 trainer.py: 892: Synchronizing meters
INFO 2025-02-12 16:07:32,671 trainer.py: 830: Losses and meters: {'Losses/train_all_loss': 0.9413489010185003, 'Losses/train_all_loss_mask': 0.005147426534676924, 'Losses/train_all_loss_dice': 0.4273555427789688, 'Losses/train_all_loss_iou': 0.26627630763687193, 'Losses/train_all_loss_class': 0.14476855689645163, 'Losses/train_all_core_loss': 0.9413489010185003, 'Trainer/where': 0.74375, 'Trainer/epoch': 14, 'Trainer/steps_train': 120}




INFO 2025-02-12 16:07:39,455 train_utils.py: 271: Train Epoch: [15][0/8] | Batch Time: 5.61 (5.61) | Data Time: 4.58 (4.58) | Mem (GB): 42.00 (42.00/42.00) | Time Elapsed: 00d 00h 03m | Losses/train_all_loss: 4.03e+00 (4.03e+00)
INFO 2025-02-12 16:07:46,867 trainer.py: 950: Estimated time remaining: 00d 00h 00m
INFO 2025-02-12 16:07:46,869 trainer.py: 892: Synchronizing meters
INFO 2025-02-12 16:07:46,870 trainer.py: 830: Losses and meters: {'Losses/train_all_loss': 1.0949368737637997, 'Losses/train_all_loss_mask': 0.01823466735368129, 'Losses/train_all_loss_dice': 0.39740641601383686, 'Losses/train_all_loss_iou': 0.17691878648474813, 'Losses/train_all_loss_class': 0.15591835734812776, 'Losses/train_all_core_loss': 1.0949368737637997, 'Trainer/where': 0.79375, 'Trainer/epoch': 15, 'Trainer/steps_train': 128}
INFO 2025-02-12 16:07:53,736 train_utils.py: 271: Train Epoch: [16][0/8] | Batch Time: 5.62 (5.62) | Data Time: 4.49 (4.49) | Mem (GB): 43.00 (43.00/43.00) | Time Elapsed: 00d 00h 

In [11]:
# Print a marker once the cells above have returned ("Finito" is Italian for
# "finished").
print("Finito")

Finito
