In [1]:
import logging
import os
import sys
import traceback

import torch
from hydra import compose, initialize_config_module
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
from omegaconf import OmegaConf

from training.utils.train_utils import makedir, register_omegaconf_resolvers

# Exported before any Hydra call in the cells below. Hydra documents
# HYDRA_FULL_ERROR=1 as the switch for printing complete stack traces.
os.environ["HYDRA_FULL_ERROR"] = "1"

In [2]:
def single_proc_run(local_rank, main_port, cfg, world_size):
    """Configure the process environment and run the trainer in this process.

    Writes MASTER_ADDR, MASTER_PORT, RANK, LOCAL_RANK and WORLD_SIZE into
    ``os.environ`` (all values as strings), attempts to register the OmegaConf
    resolvers, then instantiates ``cfg.trainer`` and calls its ``run()``.

    Args:
        local_rank: Exported as both RANK and LOCAL_RANK.
        main_port: Exported as MASTER_PORT.
        cfg: Config object whose ``trainer`` node is handed to ``instantiate``.
        world_size: Exported as WORLD_SIZE.
    """
    exported = (
        ("MASTER_ADDR", "localhost"),
        ("MASTER_PORT", main_port),
        ("RANK", local_rank),
        ("LOCAL_RANK", local_rank),
        ("WORLD_SIZE", world_size),
    )
    for name, value in exported:
        os.environ[name] = str(value)

    # Resolver registration is best-effort: any failure is logged at INFO level
    # and training continues.
    try:
        register_omegaconf_resolvers()
    except Exception as err:
        logging.info(err)

    # _recursive_=False is forwarded to instantiate unchanged.
    instantiate(cfg.trainer, _recursive_=False).run()


def single_node_runner(cfg, main_port: int):
    """Run training on this machine as a single process.

    Makes sure the ``torch.multiprocessing`` start method is ``"spawn"`` and then
    calls ``single_proc_run`` with ``local_rank=0`` and ``world_size=1``.

    Args:
        cfg: Config object forwarded to ``single_proc_run``.
        main_port: Port forwarded to ``single_proc_run`` (exported there as
            MASTER_PORT).
    """
    # CUDA runtime does not support `fork`.
    # set_start_method() raises RuntimeError once a start method has been fixed,
    # which happens as soon as this function is called a second time in the same
    # kernel. Only switch when "spawn" is not already active, and force the
    # switch in that case.
    if torch.multiprocessing.get_start_method(allow_none=True) != "spawn":
        torch.multiprocessing.set_start_method("spawn", force=True)

    single_proc_run(local_rank=0, main_port=main_port, cfg=cfg, world_size=1)


def format_exception(e: Exception, limit=20):
    """Render an exception as ``"<Type>: <message>"`` followed by its traceback.

    Args:
        e: The exception to describe; its ``__traceback__`` supplies the frames.
        limit: Passed to ``traceback.format_tb`` to cap the number of entries.

    Returns:
        A string made of the type name and message, the line ``Traceback:``, and
        the formatted traceback entries concatenated.
    """
    frames = traceback.format_tb(e.__traceback__, limit=limit)
    header = f"{type(e).__name__}: {e}"
    return header + "\nTraceback:\n" + "".join(frames)


def add_pythonpath_to_sys_path():
    """Put the entries of the ``PYTHONPATH`` environment variable first in ``sys.path``.

    Does nothing when ``PYTHONPATH`` is unset or empty. Calling the function
    again with the same ``PYTHONPATH`` leaves ``sys.path`` unchanged instead of
    stacking another copy of the entries in front.
    """
    pythonpath = os.environ.get("PYTHONPATH")
    if not pythonpath:
        return

    # os.pathsep is the platform's search-path separator (":" on POSIX, ";" on
    # Windows); a literal ":" would cut Windows drive letters in half.
    entries = pythonpath.split(os.pathsep)

    # Already in front (e.g. the notebook cell was re-run): nothing to do.
    if sys.path[: len(entries)] == entries:
        return

    # Mutate in place so that existing references to the sys.path list see the
    # new entries too.
    sys.path[:0] = entries



In [3]:
# Start Hydra's compose API against the `sam2` config module.
# Hydra keeps one global instance and raises if it is initialized twice, so drop
# any instance left over from an earlier run of this cell first. This keeps the
# cell re-runnable without restarting the kernel.
GlobalHydra.instance().clear()
initialize_config_module("sam2", version_base="1.2")

# NOTE(review): registering resolvers that already exist is expected to raise
# ValueError (OmegaConf's behavior when `replace` is not set) — confirm against
# register_omegaconf_resolvers. Only that case is tolerated here, logged the same
# way single_proc_run logs it; any other error still surfaces.
try:
    register_omegaconf_resolvers()
except ValueError as e:
    logging.info(e)

In [4]:
# Load the training config `sam2.1_hiera_b+_MOSE_finetune.yaml` through Hydra's
# compose API. Every later cell reads or mutates this `cfg` object.
cfg = compose(config_name="configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml")

In [5]:
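# Optional override, currently disabled: switch the first training video dataset's
# `_target_` class to JSONRawDataset. Uncomment the next line to enable it.
# cfg.trainer.data.train.datasets[0].dataset.datasets[0].video_dataset._target_ = 'training.dataset.vos_raw_dataset.JSONRawDataset'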

In [6]:
# Customize the config
cfg.scratch.max_num_objects = 3
cfg.scratch.num_epochs = 20
cfg.launcher.gpus_per_node = 1
cfg.launcher.num_nodes = 1
cfg.dataset.img_folder = "/home/kasm-user/sam2_ft_runpod/prepped_mini_dataset_png/images"
cfg.dataset.gt_folder = "/home/kasm-user/sam2_ft_runpod/prepped_mini_dataset_png/annotations"
cfg.dataset.file_list_txt = "/home/kasm-user/sam2_ft_runpod/prepped_mini_dataset_png/list_files.txt"
cfg.trainer.checkpoint.model_weight_initializer.state_dict.checkpoint_path = (
    "/home/kasm-user/sam2_ft_runpod/checkpoints/sam2.1_hiera_base_plus.pt"
)

In [7]:
# When the config leaves the experiment log directory unset, fall back to
# <current working directory>/sam2_logs/experiment_log_dir.
if cfg.launcher.experiment_log_dir is None:
    default_log_dir = os.path.join(os.getcwd(), "sam2_logs", "experiment_log_dir")
    cfg.launcher.experiment_log_dir = default_log_dir

In [8]:
# Dump the final config as YAML between two banner lines so it is recorded with
# the run's output.
print("###################### Train App Config ####################")
config_yaml = OmegaConf.to_yaml(cfg)
print(config_yaml)
print("############################################################")

###################### Train App Config ####################
scratch:
  resolution: 1024
  train_batch_size: 1
  num_train_workers: 10
  num_frames: 8
  max_num_objects: 3
  base_lr: 5.0e-06
  vision_lr: 3.0e-06
  phases_per_epoch: 1
  num_epochs: 20
dataset:
  img_folder: /home/kasm-user/sam2_ft_runpod/prepped_mini_dataset_png/images
  gt_folder: /home/kasm-user/sam2_ft_runpod/prepped_mini_dataset_png/annotations
  file_list_txt: /home/kasm-user/sam2_ft_runpod/prepped_mini_dataset_png/list_files.txt
  multiplier: 2
vos:
  train_transforms:
  - _target_: training.dataset.transforms.ComposeAPI
    transforms:
    - _target_: training.dataset.transforms.RandomHorizontalFlip
      consistent_transform: true
    - _target_: training.dataset.transforms.RandomAffine
      degrees: 25
      shear: 20
      image_interpolation: bilinear
      consistent_transform: true
    - _target_: training.dataset.transforms.RandomResizeAPI
      sizes: ${scratch.resolution}
      square: true
      consis

In [9]:
# Put the PYTHONPATH entries in front of sys.path before training code is imported.
add_pythonpath_to_sys_path()
# Called on the experiment log directory chosen above. As the last expression of
# the cell, makedir's return value is what the notebook displays.
makedir(cfg.launcher.experiment_log_dir)

True

In [10]:
# Start training in this kernel process. 4500 is the `main_port` argument, which
# single_proc_run exports as MASTER_PORT.
single_node_runner(cfg, 4500)

INFO 2025-02-12 16:36:57,239 train_utils.py: 108: MACHINE SEED: 2460
INFO 2025-02-12 16:36:57,243 train_utils.py: 154: Logging ENV_VARIABLES
INFO 2025-02-12 16:36:57,244 train_utils.py: 155: AUDIO_PORT=4901
CLICOLOR=1
CLICOLOR_FORCE=1
COLORTERM=truecolor
CONDA_DEFAULT_ENV=sam2_ft
CONDA_EXE=/home/kasm-user/miniconda3/bin/conda
CONDA_PREFIX=/home/kasm-user/miniconda3/envs/sam2_ft
CONDA_PREFIX_1=/home/kasm-user/miniconda3
CONDA_PROMPT_MODIFIER=(sam2_ft) 
CONDA_PYTHON_EXE=/home/kasm-user/miniconda3/bin/python
CONDA_SHLVL=2
CUDA_MODULE_LOADING=LAZY
DBUS_SESSION_BUS_ADDRESS=unix:abstract=/tmp/dbus-oeGWpnCXQH,guid=7b0087612934a81b822e3a5467accc2f
DEBIAN_FRONTEND=noninteractive
DESKTOP_SESSION=xfce
DISPLAY=:1.0
DISTRO=ubuntu
FORCE_COLOR=1
GIT_PAGER=cat
GOMP_SPINCOUNT=0
HOME=/home/kasm-user
HOSTNAME=82995ba0f5a4
HYDRA_FULL_ERROR=1
INST_SCRIPTS=/dockerstartup/install
JPY_PARENT_PID=6294
JPY_SESSION_NAME=/home/kasm-user/sam2_ft_runpod/training_sav_png/train.ipynb
JUPYTER_PASSWORD=<redacted>

grad.sizes() = [64, 256, 1, 1], strides() = [256, 1, 256, 256]
bucket_view.sizes() = [64, 256, 1, 1], strides() = [256, 1, 1, 1] (Triggered internally at /pytorch/torch/csrc/distributed/c10d/reducer.cpp:327.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


INFO 2025-02-12 16:37:16,075 trainer.py: 950: Estimated time remaining: 00d 00h 05m
INFO 2025-02-12 16:37:16,079 trainer.py: 892: Synchronizing meters
INFO 2025-02-12 16:37:16,081 trainer.py: 830: Losses and meters: {'Losses/train_all_loss': 53.974300384521484, 'Losses/train_all_loss_mask': 1.8846410202095285, 'Losses/train_all_loss_dice': 9.971903383731842, 'Losses/train_all_loss_iou': 1.898172228210359, 'Losses/train_all_loss_class': 4.411401152610779, 'Losses/train_all_core_loss': 53.974300384521484, 'Trainer/where': 0.04375, 'Trainer/epoch': 0, 'Trainer/steps_train': 8}
INFO 2025-02-12 16:37:23,257 train_utils.py: 271: Train Epoch: [1][0/8] | Batch Time: 5.80 (5.80) | Data Time: 4.87 (4.87) | Mem (GB): 38.00 (38.00/38.00) | Time Elapsed: 00d 00h 00m | Losses/train_all_loss: 1.77e+01 (1.77e+01)
INFO 2025-02-12 16:37:31,142 trainer.py: 950: Estimated time remaining: 00d 00h 03m
INFO 2025-02-12 16:37:31,145 trainer.py: 892: Synchronizing meters
INFO 2025-02-12 16:37:31,146 trainer.py:



INFO 2025-02-12 16:37:38,345 train_utils.py: 271: Train Epoch: [2][0/8] | Batch Time: 5.77 (5.77) | Data Time: 4.67 (4.67) | Mem (GB): 42.00 (42.00/42.00) | Time Elapsed: 00d 00h 00m | Losses/train_all_loss: 2.26e+01 (2.26e+01)
INFO 2025-02-12 16:37:46,008 trainer.py: 950: Estimated time remaining: 00d 00h 03m
INFO 2025-02-12 16:37:46,010 trainer.py: 892: Synchronizing meters
INFO 2025-02-12 16:37:46,011 trainer.py: 830: Losses and meters: {'Losses/train_all_loss': 14.919279918074608, 'Losses/train_all_loss_mask': 0.08977580085047521, 'Losses/train_all_loss_dice': 8.525644212961197, 'Losses/train_all_loss_iou': 1.6903906010120409, 'Losses/train_all_loss_class': 2.9077292159199715, 'Losses/train_all_core_loss': 14.919279918074608, 'Trainer/where': 0.14375, 'Trainer/epoch': 2, 'Trainer/steps_train': 24}
INFO 2025-02-12 16:37:55,030 train_utils.py: 271: Train Epoch: [3][0/8] | Batch Time: 7.63 (7.63) | Data Time: 6.54 (6.54) | Mem (GB): 42.00 (42.00/42.00) | Time Elapsed: 00d 00h 00m | Lo



INFO 2025-02-12 16:42:06,138 train_utils.py: 271: Train Epoch: [18][0/8] | Batch Time: 7.57 (7.57) | Data Time: 6.65 (6.65) | Mem (GB): 38.00 (38.00/38.00) | Time Elapsed: 00d 00h 05m | Losses/train_all_loss: 4.42e+00 (4.42e+00)
INFO 2025-02-12 16:42:13,965 trainer.py: 950: Estimated time remaining: 00d 00h 00m
INFO 2025-02-12 16:42:13,967 trainer.py: 892: Synchronizing meters
INFO 2025-02-12 16:42:13,968 trainer.py: 830: Losses and meters: {'Losses/train_all_loss': 10.042121917009354, 'Losses/train_all_loss_mask': 0.08896358752190281, 'Losses/train_all_loss_dice': 6.350963672623038, 'Losses/train_all_loss_iou': 0.8702553386101499, 'Losses/train_all_loss_class': 1.0416307719424367, 'Losses/train_all_core_loss': 10.042121917009354, 'Trainer/where': 0.94375, 'Trainer/epoch': 18, 'Trainer/steps_train': 152}
INFO 2025-02-12 16:42:21,481 train_utils.py: 271: Train Epoch: [19][0/8] | Batch Time: 6.11 (6.11) | Data Time: 4.77 (4.77) | Mem (GB): 43.00 (43.00/43.00) | Time Elapsed: 00d 00h 05m 

In [12]:
# Completion marker: printed only once the cells above have finished running.
print("Finito")

Finito
