### Train an MMDetection Network
- See [tutorial](https://github.com/open-mmlab/mmdetection/blob/main/demo/MMDet_Tutorial.ipynb)

In [1]:
import os

# Enumerate CUDA devices by PCI bus ID (the ordering nvidia-smi uses; see issue #152),
# then expose only GPU 0 to this kernel.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

In [7]:
from datetime import datetime
from pathlib import Path
import shlex
import sys

### Download checkpoint for a pretrained model (if desired)
Alternatively, use a previously trained mouse model as the pretrained model (set `pretrained_model` in the parameters cell below).

In [9]:
# Local cache for configs/checkpoints fetched with `mim download` (created if missing)
pretrained_model_directory = Path("/n/groups/datta/tim_sainburg/datasets/scratch") / "pretrained_mm_models"
pretrained_model_directory.mkdir(exist_ok=True, parents=True)

In [10]:
# find models here: https://github.com/open-mmlab/mmdetection/tree/main/configs/rtmdet
# Config name passed to `mim download` below; it should use the same architecture as `config_loc`
# in the parameters cell.
pretrain_model = "rtmdet_s_8xb32-300e_coco"

In [11]:
# Build the shell command that activates this kernel's conda env and downloads the
# config + checkpoint with `mim`. The env root is two levels above the interpreter
# (<env>/bin/python3 -> <env>). Every interpolated value is shell-quoted, so paths
# containing spaces or shell metacharacters cannot break or alter the command.
conda_env_root = Path(sys.executable).parents[1]
command = (
    f"source activate {shlex.quote(conda_env_root.as_posix())}; "
    f"mim download mmdet --config {shlex.quote(pretrain_model)} "
    f"--dest {shlex.quote(pretrained_model_directory.as_posix())}"
)
print(command)

source activate /n/groups/datta/tim_sainburg/conda_envs/mmdeploy; mim download mmdet --config rtmdet_s_8xb32-300e_coco --dest /n/groups/datta/tim_sainburg/datasets/scratch/pretrained_mm_models


In [12]:
# Run the download command built above (IPython shell escape; {command} interpolates the Python variable)
!{command}

processing rtmdet_s_8xb32-300e_coco...
[32mrtmdet_s_8xb32-300e_coco_20220905_161602-387a891e.pth exists in /n/groups/datta/tim_sainburg/datasets/scratch/pretrained_mm_models[0m
[32mSuccessfully dumped rtmdet_s_8xb32-300e_coco.py to /n/groups/datta/tim_sainburg/datasets/scratch/pretrained_mm_models[0m


In [13]:
# List the cache directory to confirm the checkpoint (.pth) and config (.py) are present
!ls {pretrained_model_directory.as_posix()}

rtmdet_s_8xb32-300e_coco_20220905_161602-387a891e.pth
rtmdet_s_8xb32-300e_coco.py


In [14]:
# Display the resolved download directory
pretrained_model_directory

PosixPath('/n/groups/datta/tim_sainburg/datasets/scratch/pretrained_mm_models')

### Parameters and dataset

In [None]:
model_name = 'rtmdet_small_8xb32-300e_coco_chronic'

# Where the COCO format dataset is located (created in the previous notebook)
dataset_directory = Path("/n/groups/datta/tim_sainburg/projects/24-04-02-neuropixels-chronic/data/keypoints/coco-trainingsets/240408-mmpose-multianimal-chronic_v3/")

# which config to use (this is what we base the config off of). Should be in the mmdetection repo.
config_loc = Path('/n/groups/datta/tim_sainburg/projects/mmdetection/configs/rtmdet/rtmdet_s_8xb32-300e_coco.py')

# which pretrained model to use (point to .pth file). Pretrained model should be the same model architecture.
#pretrained_model = pretrained_model_directory / "rtmdet_s_8xb32-300e_coco_20220905_161602-387a891e.pth"
# NOTE(review): this checkpoint comes from an rtmdet_tiny run while config_loc is rtmdet_s;
# layers whose shapes differ will likely not be loaded -- confirm this is intended.
pretrained_model = Path('/n/groups/datta/tim_sainburg/projects/24-01-05-multicamera_keypoints_mm2d/models/rtmdet/rtmdet_tiny_8xb32-300e_coco_24-01-05-11-25-00_102726/epoch_300.pth')
use_pretrained_model = True

# Fail fast on bad inputs, before an empty run directory is created below and
# before a long runner build (a missing checkpoint would otherwise only fail at train time).
required_paths = {'dataset_directory': dataset_directory, 'config_loc': config_loc}
if use_pretrained_model:
    required_paths['pretrained_model'] = pretrained_model
missing_paths = {name: path.as_posix() for name, path in required_paths.items() if not path.exists()}
if missing_paths:
    raise FileNotFoundError(f"Missing input path(s): {missing_paths}")

# working directory (where model output is saved); timestamped so runs never collide
output_directory = Path("/n/groups/datta/tim_sainburg/datasets/scratch/mm_training")
formatted_datetime = datetime.now().strftime("%y-%m-%d-%H-%M-%S")
working_directory = (output_directory / 'rtmdet' / f"{model_name}_{formatted_datetime}")
working_directory.mkdir(parents=True, exist_ok=True)

# You shouldn't need to change anything below here

### Display compute / environment info (for future reference)

In [16]:
# Record toolchain versions (for reproducibility of this training run)
# Check nvcc version
!nvcc -V
# Check GCC version
!gcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Mon_Apr__3_17:16:06_PDT_2023
Cuda compilation tools, release 12.1, V12.1.105
Build cuda_12.1.r12.1/compiler.32688072_0
gcc (GCC) 9.2.0
Copyright (C) 2019 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.



In [17]:
from mmengine.utils import get_git_hash
from mmengine.utils.dl_utils import collect_env as collect_base_env
import sys
import mmdet
import torch, torchvision
import mmpose
from mmcv.ops import get_compiling_cuda_version, get_compiler_version

def collect_env():
    """Return the mmengine environment report extended with the MMDetection version.

    The added ``'MMDetection'`` entry has the form ``<mmdet version>+<7-char git hash>``.
    """
    report = collect_base_env()
    short_hash = get_git_hash()[:7]
    report['MMDetection'] = f'{mmdet.__version__}+{short_hash}'
    return report

# Interpreter used by this kernel, then the full environment report
print(f"Environment: {sys.executable}")
for key, value in collect_env().items():
    print(f'{key}: {value}')

# Check PyTorch installation (CUDA/compiler versions are those mmcv's ops were compiled with)
print(f'cuda version: {get_compiling_cuda_version()}')
print(f'compiler information: {get_compiler_version()}')
print(f'torch version: {torch.__version__} {torch.cuda.is_available()}')
print(f'torchvision version: {torchvision.__version__}')
print(f'mmpose version: {mmpose.__version__}')

Environment: /n/groups/datta/tim_sainburg/conda_envs/mmdeploy/bin/python3
sys.platform: linux
Python: 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0]
CUDA available: True
numpy_random_seed: 2147483648
GPU 0: NVIDIA L40S
CUDA_HOME: /n/app/cuda/12.1-gcc-9.2.0
NVCC: Cuda compilation tools, release 12.1, V12.1.105
GCC: gcc (GCC) 9.2.0
PyTorch: 2.1.1
PyTorch compiling details: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2023.1-Product Build 20230303 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v3.1.1 (Git Hash 64f6bcbcbab628e96f33a62c3e975f8535a7bde4)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX512
  - CUDA Runtime 12.1
  - NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=

### Create the config file

In [18]:
from mmengine import Config
from pathlib import Path

In [19]:
cfg = Config.fromfile(config_loc.as_posix())

# set the dataset directory
cfg.data_root = dataset_directory.as_posix()

# set the working directory
cfg.work_dir = working_directory.as_posix()

# set the metainfo: class names plus one display colour per class.
# NOTE(review): mmdet's CocoDataset selects annotation categories by these names --
# verify they match the category names in the COCO annotation files.
metainfo = {
    'classes': ('Mouse', ),
    'palette': [
        (220, 20, 60),
    ]
}
cfg.metainfo = metainfo

# size the detection head from the class list so the two cannot drift apart
# (currently a single 'Mouse' class)
cfg.model.bbox_head.num_classes = len(metainfo['classes'])

# specify the dataset
cfg.dataset_type = 'CocoDataset'

# load the pre-trained weights chosen in the parameters cell
if use_pretrained_model:
    cfg.load_from = pretrained_model.as_posix()

In [20]:
# Point the train / val / test dataloaders at the custom COCO-format dataset.
# The test dataloader is updated too: test_evaluator below uses the val annotations,
# and the dumped config.py is what later testing/inference reads.
split_settings = {
    'train_dataloader': ('train/', 'annotations/instances_train.json'),
    'val_dataloader': ('val/', 'annotations/instances_val.json'),
    'test_dataloader': ('val/', 'annotations/instances_val.json'),
}
for loader_key, (img_prefix, ann_file) in split_settings.items():
    loader_dataset = cfg[loader_key].dataset
    loader_dataset.type = cfg.dataset_type
    loader_dataset.data_root = cfg.data_root
    loader_dataset.metainfo = cfg.metainfo
    loader_dataset.data_prefix = dict(img=img_prefix)
    loader_dataset.ann_file = ann_file

# COCO metric evaluators need the absolute annotation path
val_ann_file = cfg.data_root + '/annotations/instances_val.json'
cfg.val_evaluator.ann_file = val_ann_file
cfg.test_evaluator.ann_file = val_ann_file

# save a checkpoint every 50 epochs, keeping the 15 most recent
cfg.default_hooks.checkpoint.max_keep_ckpts = 15
cfg.default_hooks.checkpoint.interval = 50

max_epochs = 2000
cfg.max_epochs = max_epochs
cfg.train_cfg.max_epochs = max_epochs

# NOTE(review): in the upstream RTMDet configs, the epoch-dependent entries below are
# computed from max_epochs (300) when the file is parsed, so changing max_epochs alone
# leaves them tied to a 300-epoch schedule. Rescale any that are present -- verify
# against config_loc.
stage2_num_epochs = cfg.get('stage2_num_epochs', 20)
stage2_start_epoch = max_epochs - stage2_num_epochs
for scheduler in cfg.get('param_scheduler', []):
    if scheduler.get('type') == 'CosineAnnealingLR':
        # cosine decay over the second half of training
        scheduler['begin'] = max_epochs // 2
        scheduler['end'] = max_epochs
        scheduler['T_max'] = max_epochs // 2
for hook in cfg.get('custom_hooks', []):
    if hook.get('type') == 'PipelineSwitchHook':
        # switch to the stage-2 training pipeline only for the final epochs
        hook['switch_epoch'] = stage2_start_epoch
if cfg.train_cfg.get('dynamic_intervals'):
    # validate every epoch only during the final stage
    cfg.train_cfg['dynamic_intervals'] = [(stage2_start_epoch, 1)]

In [21]:
# Inspect the resolved model config (e.g. confirm bbox_head.num_classes is 1)
print(cfg.model)

{'type': 'RTMDet', 'data_preprocessor': {'type': 'DetDataPreprocessor', 'mean': [103.53, 116.28, 123.675], 'std': [57.375, 57.12, 58.395], 'bgr_to_rgb': False, 'batch_augments': None}, 'backbone': {'type': 'CSPNeXt', 'arch': 'P5', 'expand_ratio': 0.5, 'deepen_factor': 0.33, 'widen_factor': 0.5, 'channel_attention': True, 'norm_cfg': {'type': 'SyncBN'}, 'act_cfg': {'type': 'SiLU', 'inplace': True}, 'init_cfg': {'type': 'Pretrained', 'prefix': 'backbone.', 'checkpoint': 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-s_imagenet_600e.pth'}}, 'neck': {'type': 'CSPNeXtPAFPN', 'in_channels': [128, 256, 512], 'out_channels': 128, 'num_csp_blocks': 1, 'expand_ratio': 0.5, 'norm_cfg': {'type': 'SyncBN'}, 'act_cfg': {'type': 'SiLU', 'inplace': True}}, 'bbox_head': {'type': 'RTMDetSepBNHead', 'num_classes': 1, 'in_channels': 128, 'stacked_convs': 2, 'feat_channels': 128, 'anchor_generator': {'type': 'MlvlPointGenerator', 'offset': 0, 'strides': [8, 16, 32]}, '

In [22]:
# save configuration file for future reference
# (runs after all overrides above, so it records the config actually used for training)
cfg.dump(working_directory / 'config.py')

In [23]:
# Show where config.py, logs and checkpoints for this run are written
print(working_directory)

/n/groups/datta/tim_sainburg/datasets/scratch/mm_training/rtmdet/rtmdet_small_8xb32-300e_coco_chronic_24-08-15-10-35-59


### Train

In [24]:
from mmengine.config import Config, DictAction
from mmengine.runner import Runner

In [25]:
# build the runner from config
# (on a single GPU, mmengine reverts the config's SyncBN layers to regular BatchNorm -- see log)
runner = Runner.from_cfg(cfg)

08/15 10:36:14 - mmengine - [4m[97mINFO[0m - 
------------------------------------------------------------
System environment:
    sys.platform: linux
    Python: 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0]
    CUDA available: True
    numpy_random_seed: 1941649245
    GPU 0: NVIDIA L40S
    CUDA_HOME: /n/app/cuda/12.1-gcc-9.2.0
    NVCC: Cuda compilation tools, release 12.1, V12.1.105
    GCC: gcc (GCC) 9.2.0
    PyTorch: 2.1.1
    PyTorch compiling details: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2023.1-Product Build 20230303 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v3.1.1 (Git Hash 64f6bcbcbab628e96f33a62c3e975f8535a7bde4)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX512
  - CUDA Runtime 12.1
  - NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-genco

08/15 10:36:19 - mmengine - [4m[97mINFO[0m - Distributed training is not used, all SyncBatchNorm (SyncBN) layers in the model will be automatically reverted to BatchNormXd layers if they are used.
08/15 10:36:19 - mmengine - [4m[97mINFO[0m - Hooks will be executed in the following order:
before_run:
(VERY_HIGH   ) RuntimeInfoHook                    
(49          ) EMAHook                            
(BELOW_NORMAL) LoggerHook                         
 -------------------- 
after_load_checkpoint:
(49          ) EMAHook                            
 -------------------- 
before_train:
(VERY_HIGH   ) RuntimeInfoHook                    
(49          ) EMAHook                            
(NORMAL      ) IterTimerHook                      
(VERY_LOW    ) CheckpointHook                     
 -------------------- 
before_train_epoch:
(VERY_HIGH   ) RuntimeInfoHook                    
(NORMAL      ) IterTimerHook                      
(NORMAL      ) DistSamplerSeedHook                
(NORMA

In [None]:
# start training; checkpoints and logs are written to cfg.work_dir (the working directory above)
runner.train()

loading annotations into memory...
Done (t=0.13s)
creating index...
index created!
08/15 10:36:21 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stem.0.bn.weight:weight_decay=0.0
08/15 10:36:21 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stem.0.bn.bias:weight_decay=0.0
08/15 10:36:21 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stem.1.bn.weight:weight_decay=0.0
08/15 10:36:21 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stem.1.bn.bias:weight_decay=0.0
08/15 10:36:21 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stem.2.bn.weight:weight_decay=0.0
08/15 10:36:21 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stem.2.bn.bias:weight_decay=0.0
08/15 10:36:21 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stage1.0.bn.weight:weight_decay=0.0
08/15 10:36:21 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stage1.0.bn.bias:weight_decay=0.0
08/15 10:36:21 - mmengine

08/15 10:36:21 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stage4.0.bn.bias:weight_decay=0.0
08/15 10:36:21 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stage4.1.conv1.bn.weight:weight_decay=0.0
08/15 10:36:21 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stage4.1.conv1.bn.bias:weight_decay=0.0
08/15 10:36:21 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stage4.1.conv2.bn.weight:weight_decay=0.0
08/15 10:36:21 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stage4.1.conv2.bn.bias:weight_decay=0.0
08/15 10:36:21 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stage4.2.main_conv.bn.weight:weight_decay=0.0
08/15 10:36:21 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stage4.2.main_conv.bn.bias:weight_decay=0.0
08/15 10:36:21 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stage4.2.short_conv.bn.weight:weight_decay=0.0
08/15 10:36:21 - mmengine - [4m[97mINFO

08/15 10:36:21 - mmengine - [4m[97mINFO[0m - paramwise_options -- neck.bottom_up_blocks.1.main_conv.bn.bias:weight_decay=0.0
08/15 10:36:21 - mmengine - [4m[97mINFO[0m - paramwise_options -- neck.bottom_up_blocks.1.short_conv.bn.weight:weight_decay=0.0
08/15 10:36:21 - mmengine - [4m[97mINFO[0m - paramwise_options -- neck.bottom_up_blocks.1.short_conv.bn.bias:weight_decay=0.0
08/15 10:36:21 - mmengine - [4m[97mINFO[0m - paramwise_options -- neck.bottom_up_blocks.1.final_conv.bn.weight:weight_decay=0.0
08/15 10:36:21 - mmengine - [4m[97mINFO[0m - paramwise_options -- neck.bottom_up_blocks.1.final_conv.bn.bias:weight_decay=0.0
08/15 10:36:21 - mmengine - [4m[97mINFO[0m - paramwise_options -- neck.bottom_up_blocks.1.blocks.0.conv1.bn.weight:weight_decay=0.0
08/15 10:36:21 - mmengine - [4m[97mINFO[0m - paramwise_options -- neck.bottom_up_blocks.1.blocks.0.conv1.bn.bias:weight_decay=0.0
08/15 10:36:21 - mmengine - [4m[97mINFO[0m - paramwise_options -- neck.bottom_up_

08/15 10:36:23 - mmengine - [4m[97mINFO[0m - Checkpoints will be saved to /n/groups/datta/tim_sainburg/datasets/scratch/mm_training/rtmdet/rtmdet_small_8xb32-300e_coco_chronic_24-08-15-10-35-59.


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


08/15 10:36:46 - mmengine - [4m[97mINFO[0m - Epoch(train)   [1][ 50/416]  base_lr: 1.9623e-04 lr: 1.9623e-04  eta: 16:05:05  time: 0.4642  data_time: 0.0758  memory: 14066  loss: 2.4784  loss_cls: 1.7812  loss_bbox: 0.6972
08/15 10:37:03 - mmengine - [4m[97mINFO[0m - Epoch(train)   [1][100/416]  base_lr: 3.9643e-04 lr: 3.9643e-04  eta: 14:08:23  time: 0.3522  data_time: 0.0031  memory: 14066  loss: 1.0126  loss_cls: 0.5054  loss_bbox: 0.5073
08/15 10:37:21 - mmengine - [4m[97mINFO[0m - Epoch(train)   [1][150/416]  base_lr: 5.9663e-04 lr: 5.9663e-04  eta: 13:25:39  time: 0.3470  data_time: 0.0029  memory: 14066  loss: 0.8079  loss_cls: 0.3554  loss_bbox: 0.4525
08/15 10:37:38 - mmengine - [4m[97mINFO[0m - Epoch(train)   [1][200/416]  base_lr: 7.9683e-04 lr: 7.9683e-04  eta: 13:02:40  time: 0.3442  data_time: 0.0028  memory: 14066  loss: 0.7820  loss_cls: 0.3414  loss_bbox: 0.4407
08/15 10:37:55 - mmengine - [4m[97mINFO[0m - Epoch(train)   [1][250/416]  base_lr: 9.9703e-04

08/15 10:47:21 - mmengine - [4m[97mINFO[0m - Epoch(train)   [5][200/416]  base_lr: 4.0000e-03 lr: 4.0000e-03  eta: 12:03:34  time: 0.3472  data_time: 0.0023  memory: 14066  loss: 0.6656  loss_cls: 0.2579  loss_bbox: 0.4076
08/15 10:47:38 - mmengine - [4m[97mINFO[0m - Epoch(train)   [5][250/416]  base_lr: 4.0000e-03 lr: 4.0000e-03  eta: 12:03:07  time: 0.3503  data_time: 0.0025  memory: 14066  loss: 0.6596  loss_cls: 0.2505  loss_bbox: 0.4091
08/15 10:47:56 - mmengine - [4m[97mINFO[0m - Epoch(train)   [5][300/416]  base_lr: 4.0000e-03 lr: 4.0000e-03  eta: 12:02:40  time: 0.3499  data_time: 0.0024  memory: 14066  loss: 0.6589  loss_cls: 0.2488  loss_bbox: 0.4101
08/15 10:48:09 - mmengine - [4m[97mINFO[0m - Exp name: rtmdet_s_8xb32-300e_coco_20240815_103613
08/15 10:48:13 - mmengine - [4m[97mINFO[0m - Epoch(train)   [5][350/416]  base_lr: 4.0000e-03 lr: 4.0000e-03  eta: 12:02:08  time: 0.3483  data_time: 0.0022  memory: 14066  loss: 0.6709  loss_cls: 0.2538  loss_bbox: 0.41

08/15 10:57:39 - mmengine - [4m[97mINFO[0m - Epoch(train)   [9][300/416]  base_lr: 4.0000e-03 lr: 4.0000e-03  eta: 11:50:28  time: 0.3531  data_time: 0.0025  memory: 14066  loss: 0.6315  loss_cls: 0.2385  loss_bbox: 0.3931
08/15 10:57:57 - mmengine - [4m[97mINFO[0m - Epoch(train)   [9][350/416]  base_lr: 4.0000e-03 lr: 4.0000e-03  eta: 11:50:11  time: 0.3519  data_time: 0.0024  memory: 14066  loss: 0.6230  loss_cls: 0.2249  loss_bbox: 0.3981
08/15 10:58:14 - mmengine - [4m[97mINFO[0m - Epoch(train)   [9][400/416]  base_lr: 4.0000e-03 lr: 4.0000e-03  eta: 11:49:44  time: 0.3461  data_time: 0.0022  memory: 14066  loss: 0.6357  loss_cls: 0.2353  loss_bbox: 0.4004
08/15 10:58:19 - mmengine - [4m[97mINFO[0m - Exp name: rtmdet_s_8xb32-300e_coco_20240815_103613
08/15 10:58:39 - mmengine - [4m[97mINFO[0m - Epoch(train)  [10][ 50/416]  base_lr: 4.0000e-03 lr: 4.0000e-03  eta: 11:50:13  time: 0.4007  data_time: 0.0447  memory: 14066  loss: 0.6276  loss_cls: 0.2306  loss_bbox: 0.39

08/15 11:08:50 - mmengine - [4m[97mINFO[0m - Exp name: rtmdet_s_8xb32-300e_coco_20240815_103613
08/15 11:08:56 - mmengine - [4m[97mINFO[0m - Exp name: rtmdet_s_8xb32-300e_coco_20240815_103613
08/15 11:09:17 - mmengine - [4m[97mINFO[0m - Epoch(train)  [13][ 50/416]  base_lr: 4.0000e-03 lr: 4.0000e-03  eta: 12:55:34  time: 0.5373  data_time: 0.0556  memory: 14066  loss: 0.6156  loss_cls: 0.2259  loss_bbox: 0.3897
08/15 11:09:40 - mmengine - [4m[97mINFO[0m - Epoch(train)  [13][100/416]  base_lr: 4.0000e-03 lr: 4.0000e-03  eta: 12:56:52  time: 0.4717  data_time: 0.0022  memory: 14066  loss: 0.6284  loss_cls: 0.2307  loss_bbox: 0.3977
08/15 11:10:05 - mmengine - [4m[97mINFO[0m - Epoch(train)  [13][150/416]  base_lr: 4.0000e-03 lr: 4.0000e-03  eta: 12:58:42  time: 0.5010  data_time: 0.0025  memory: 14066  loss: 0.6015  loss_cls: 0.2204  loss_bbox: 0.3811
08/15 11:10:31 - mmengine - [4m[97mINFO[0m - Epoch(train)  [13][200/416]  base_lr: 4.0000e-03 lr: 4.0000e-03  eta: 13:00:

08/15 11:24:06 - mmengine - [4m[97mINFO[0m - Epoch(train)  [17][150/416]  base_lr: 4.0000e-03 lr: 4.0000e-03  eta: 13:42:57  time: 0.5012  data_time: 0.0028  memory: 14066  loss: 0.5813  loss_cls: 0.2071  loss_bbox: 0.3742
08/15 11:24:31 - mmengine - [4m[97mINFO[0m - Epoch(train)  [17][200/416]  base_lr: 4.0000e-03 lr: 4.0000e-03  eta: 13:43:55  time: 0.5099  data_time: 0.0031  memory: 14066  loss: 0.5814  loss_cls: 0.2082  loss_bbox: 0.3732
08/15 11:24:57 - mmengine - [4m[97mINFO[0m - Epoch(train)  [17][250/416]  base_lr: 4.0000e-03 lr: 4.0000e-03  eta: 13:44:53  time: 0.5117  data_time: 0.0027  memory: 14066  loss: 0.5837  loss_cls: 0.2081  loss_bbox: 0.3755
08/15 11:25:22 - mmengine - [4m[97mINFO[0m - Epoch(train)  [17][300/416]  base_lr: 4.0000e-03 lr: 4.0000e-03  eta: 13:45:44  time: 0.5046  data_time: 0.0026  memory: 14066  loss: 0.5849  loss_cls: 0.2058  loss_bbox: 0.3791
08/15 11:25:45 - mmengine - [4m[97mINFO[0m - Exp name: rtmdet_s_8xb32-300e_coco_20240815_1036

### The config (`config.py`) and checkpoints needed for inference are saved in the working directory