### Train an MMPose Network
- note that I am using the conda env at /n/groups/datta/tim_sainburg/conda_envs/openmmlab

In [1]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [2]:
from pathlib import Path
import sys
from datetime import datetime

### Download checkpoint for a pretrained model (if desired)
Alternatively, use a previous mouse model as a pretrained model

In [3]:
pretrained_model_directory = Path("/n/groups/datta/tim_sainburg/datasets/scratch/pretrained_mm_models")
pretrained_model_directory.mkdir(parents=True, exist_ok=True)

In [4]:
# find models here: https://github.com/open-mmlab/mmpose/tree/main/configs
pretrain_model = "rtmpose-m_8xb64-210e_ap10k-256x256"

In [5]:
command = f"source activate {Path(sys.executable).parents[1]}; mim download mmpose --config {pretrain_model} --dest {pretrained_model_directory.as_posix()}"
print(command)

source activate /n/groups/datta/tim_sainburg/conda_envs/openmmlab; mim download mmpose --config rtmpose-m_8xb64-210e_ap10k-256x256 --dest /n/groups/datta/tim_sainburg/datasets/scratch/pretrained_mm_models


In [6]:
!{command}

processing rtmpose-m_8xb64-210e_ap10k-256x256...
[32mrtmpose-m_simcc-ap10k_pt-aic-coco_210e-256x256-7a041aa1_20230206.pth exists in /n/groups/datta/tim_sainburg/datasets/scratch/pretrained_mm_models[0m
[32mSuccessfully dumped rtmpose-m_8xb64-210e_ap10k-256x256.py to /n/groups/datta/tim_sainburg/datasets/scratch/pretrained_mm_models[0m


In [7]:
!ls {pretrained_model_directory.as_posix()}

rtmdet_s_8xb32-300e_coco_20220905_161602-387a891e.pth
rtmdet_s_8xb32-300e_coco.py
rtmpose-m_8xb64-210e_ap10k-256x256.py
rtmpose-m_simcc-ap10k_pt-aic-coco_210e-256x256-7a041aa1_20230206.pth


In [8]:
pretrained_model_directory

PosixPath('/n/groups/datta/tim_sainburg/datasets/scratch/pretrained_mm_models')

### Parameters and dataset

In [9]:
model_name = 'rtmpose-m_8xb64-210e_ap10k-256x256'

# Where the COCO format dataset is located (created in the previous notebook)
dataset_directory = Path("/n/groups/datta/tim_sainburg/projects/24-04-02-neuropixels-chronic/data/keypoints/coco-trainingsets/240408-mmpose-multianimal-chronic_v3/")

# which config to use (this is what we base the config off of). Should be in the mmpose repo. 
config_loc = Path('/n/groups/datta/tim_sainburg/projects/mmpose/configs/animal_2d_keypoint/rtmpose/ap10k/rtmpose-m_8xb64-210e_ap10k-256x256.py')

# which pretrained model to use (point to .pth file). Pretrained model should be the same model architecture. 
#pretrained_model = pretrained_model_directory / "rtmpose-m_simcc-ap10k_pt-aic-coco_210e-256x256-7a041aa1_20230206.pth"
pretrained_model = '/n/groups/datta/tim_sainburg/projects/24-01-05-multicamera_keypoints_mm2d/models/rtmpose/rtmpose-m_8xb64-210e_ap10k-256x256_24-01-05-13-46-05_748568/best_PCK_epoch_230.pth'
use_pretrained_model = False

# working directory (where model output is saved)
output_directory = Path("/n/groups/datta/tim_sainburg/datasets/scratch/mm_training")
formatted_datetime = datetime.now().strftime("%y-%m-%d-%H-%M-%S")
working_directory = (output_directory / 'rtmpose' / f"{model_name}_{formatted_datetime}")
working_directory.mkdir(parents=True, exist_ok=True)

In [10]:
assert config_loc.exists()
assert dataset_directory.exists()

# You shouldn't need to change anything below here
(unless you are using a different skeleton model)

### Display compute / environment info (for future reference)

In [11]:
# Check nvcc version
!nvcc -V
# Check GCC version
!gcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Mon_Apr__3_17:16:06_PDT_2023
Cuda compilation tools, release 12.1, V12.1.105
Build cuda_12.1.r12.1/compiler.32688072_0
gcc (GCC) 9.2.0
Copyright (C) 2019 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.



In [12]:
from mmengine.utils import get_git_hash
from mmengine.utils.dl_utils import collect_env as collect_base_env
import sys
import mmdet
import torch, torchvision
import mmpose
from mmcv.ops import get_compiling_cuda_version, get_compiler_version

def collect_env():
    """Collect the information of the running environments."""
    env_info = collect_base_env()
    env_info['MMDetection'] = f'{mmdet.__version__}+{get_git_hash()[:7]}'
    return env_info

print(f"Environment: {sys.executable}")
for name, val in collect_env().items():
    print(f'{name}: {val}')
# Check Pytorch installation
print('cuda version:', get_compiling_cuda_version())
print('compiler information:', get_compiler_version())
print('torch version:', torch.__version__, torch.cuda.is_available())
print('torchvision version:', torchvision.__version__)
print('mmpose version:', mmpose.__version__) 

Environment: /n/groups/datta/tim_sainburg/conda_envs/openmmlab/bin/python3
sys.platform: linux
Python: 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0]
CUDA available: True
numpy_random_seed: 2147483648
GPU 0: NVIDIA L40S
CUDA_HOME: /n/app/cuda/12.1-gcc-9.2.0
NVCC: Cuda compilation tools, release 12.1, V12.1.105
GCC: gcc (GCC) 9.2.0
PyTorch: 2.1.0
PyTorch compiling details: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2023.1-Product Build 20230303 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v3.1.1 (Git Hash 64f6bcbcbab628e96f33a62c3e975f8535a7bde4)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX512
  - CUDA Runtime 12.1
  - NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code

### Register the new dataset

In [13]:
from mmpipeline.paths import PACKAGE_DIR

In [14]:
from mmpose.registry import DATASETS
from mmpose.datasets.datasets.base import BaseCocoStyleDataset

In [15]:
# this file contains info about the dataset (keypoints, skeleton, etc) needed for traiing
dataset_info_loc =  Path("/n/groups/datta/tim_sainburg/projects/multicamera_airflow_pipeline/multicamera_airflow_pipeline/tim_240731/skeletons/sainburg25pt.py")

In [16]:
@DATASETS.register_module()
class CoCo25pt(BaseCocoStyleDataset):
    METAINFO: dict = dict(from_file=dataset_info_loc)

### Create config file

In [17]:
from mmengine import Config

In [18]:
cfg = Config.fromfile(config_loc.as_posix())

In [19]:
# load COCO pre-trained weight
if use_pretrained_model:
    cfg.load_from = pretrained_model.as_posix()

In [20]:
# set the dataset directory
cfg.data_root = dataset_directory.as_posix()

# set the working directory
cfg.work_dir = working_directory.as_posix()
cfg.randomness = dict(seed=0)

In [21]:
# set dataset configs
cfg.dataset_type = 'CoCo25pt'
cfg.data_mode = 'topdown'

# number of keypoints
cfg.model.head.out_channels = 25

cfg.train_dataloader.dataset.type = cfg.dataset_type
cfg.train_dataloader.dataset.ann_file = 'annotations/instances_train.json'
cfg.train_dataloader.dataset.data_root = cfg.data_root
cfg.train_dataloader.dataset.data_prefix = dict(img='train/')


cfg.val_dataloader.dataset.type = cfg.dataset_type
cfg.val_dataloader.dataset.bbox_file = None
cfg.val_dataloader.dataset.ann_file = 'annotations/instances_val.json'
cfg.val_dataloader.dataset.data_root = cfg.data_root
cfg.val_dataloader.dataset.data_prefix = dict(img='val/')

cfg.test_dataloader.dataset.type = cfg.dataset_type
cfg.test_dataloader.dataset.bbox_file = None
cfg.test_dataloader.dataset.ann_file = 'annotations/instances_val.json'
cfg.test_dataloader.dataset.data_root = cfg.data_root
cfg.test_dataloader.dataset.data_prefix = dict(img='val/')

# set to custom datset
cfg.train_dataloader.dataset.metainfo = dict(from_file=dataset_info_loc.as_posix())
cfg.val_dataloader.dataset.metainfo = dict(from_file=dataset_info_loc.as_posix())
cfg.test_dataloader.dataset.metainfo = dict(from_file=dataset_info_loc.as_posix())

# set evaluator
cfg.val_evaluator = dict(type='PCKAccuracy')
cfg.test_evaluator = cfg.val_evaluator

cfg.default_hooks.checkpoint.save_best = 'PCK'
cfg.default_hooks.checkpoint.max_keep_ckpts = 15
cfg.default_hooks.checkpoint.interval = 50

cfg.max_epochs = 2000
cfg.train_cfg.max_epochs = 2000

In [22]:
print(cfg)

Config (path: /n/groups/datta/tim_sainburg/projects/mmpose/configs/animal_2d_keypoint/rtmpose/ap10k/rtmpose-m_8xb64-210e_ap10k-256x256.py): {'default_scope': 'mmpose', 'default_hooks': {'timer': {'type': 'IterTimerHook'}, 'logger': {'type': 'LoggerHook', 'interval': 50}, 'param_scheduler': {'type': 'ParamSchedulerHook'}, 'checkpoint': {'type': 'CheckpointHook', 'interval': 10, 'save_best': 'PCK', 'rule': 'greater', 'max_keep_ckpts': 1}, 'sampler_seed': {'type': 'DistSamplerSeedHook'}, 'visualization': {'type': 'PoseVisualizationHook', 'enable': False}, 'badcase': {'type': 'BadCaseAnalysisHook', 'enable': False, 'out_dir': 'badcase', 'metric_type': 'loss', 'badcase_thr': 5}}, 'custom_hooks': [{'type': 'EMAHook', 'ema_type': 'ExpMomentumEMA', 'momentum': 0.0002, 'update_buffers': True, 'priority': 49}, {'type': 'mmdet.PipelineSwitchHook', 'switch_epoch': 180, 'switch_pipeline': [{'type': 'LoadImage', 'backend_args': {'backend': 'local'}}, {'type': 'GetBBoxCenterScale'}, {'type': 'RandomF

In [23]:
# set preprocess configs to model
cfg.model.setdefault('data_preprocessor', cfg.get('preprocess_cfg', {}))

{'type': 'PoseDataPreprocessor',
 'mean': [123.675, 116.28, 103.53],
 'std': [58.395, 57.12, 57.375],
 'bgr_to_rgb': True}

In [24]:
# save configuration file for future reference
cfg.dump(working_directory / 'config.py')

### run network

In [25]:
from mmengine.config import Config, DictAction
from mmengine.runner import Runner

In [26]:
# build the runner from config
runner = Runner.from_cfg(cfg)

08/15 10:58:38 - mmengine - [4m[97mINFO[0m - 
------------------------------------------------------------
System environment:
    sys.platform: linux
    Python: 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0]
    CUDA available: True
    numpy_random_seed: 0
    GPU 0: NVIDIA L40S
    CUDA_HOME: /n/app/cuda/12.1-gcc-9.2.0
    NVCC: Cuda compilation tools, release 12.1, V12.1.105
    GCC: gcc (GCC) 9.2.0
    PyTorch: 2.1.0
    PyTorch compiling details: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2023.1-Product Build 20230303 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v3.1.1 (Git Hash 64f6bcbcbab628e96f33a62c3e975f8535a7bde4)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX512
  - CUDA Runtime 12.1
  - NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=c

08/15 10:58:43 - mmengine - [4m[97mINFO[0m - Distributed training is not used, all SyncBatchNorm (SyncBN) layers in the model will be automatically reverted to BatchNormXd layers if they are used.
08/15 10:58:43 - mmengine - [4m[97mINFO[0m - Hooks will be executed in the following order:
before_run:
(VERY_HIGH   ) RuntimeInfoHook                    
(49          ) EMAHook                            
(BELOW_NORMAL) LoggerHook                         
 -------------------- 
after_load_checkpoint:
(49          ) EMAHook                            
 -------------------- 
before_train:
(VERY_HIGH   ) RuntimeInfoHook                    
(49          ) EMAHook                            
(NORMAL      ) IterTimerHook                      
(VERY_LOW    ) CheckpointHook                     
 -------------------- 
before_train_epoch:
(VERY_HIGH   ) RuntimeInfoHook                    
(NORMAL      ) IterTimerHook                      
(NORMAL      ) DistSamplerSeedHook                
(NORMA

In [None]:
# start training
runner.train()



loading annotations into memory...
Done (t=0.31s)
creating index...
index created!
08/15 10:58:46 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stem.0.bn.weight:weight_decay=0.0
08/15 10:58:46 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stem.0.bn.bias:weight_decay=0.0
08/15 10:58:46 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stem.1.bn.weight:weight_decay=0.0
08/15 10:58:46 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stem.1.bn.bias:weight_decay=0.0
08/15 10:58:46 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stem.2.bn.weight:weight_decay=0.0
08/15 10:58:46 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stem.2.bn.bias:weight_decay=0.0
08/15 10:58:46 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stage1.0.bn.weight:weight_decay=0.0
08/15 10:58:46 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stage1.0.bn.bias:weight_decay=0.0
08/15 10:58:46 - mmengine

08/15 10:58:46 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stage3.1.main_conv.bn.bias:weight_decay=0.0
08/15 10:58:46 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stage3.1.short_conv.bn.weight:weight_decay=0.0
08/15 10:58:46 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stage3.1.short_conv.bn.bias:weight_decay=0.0
08/15 10:58:46 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stage3.1.final_conv.bn.weight:weight_decay=0.0
08/15 10:58:46 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stage3.1.final_conv.bn.bias:weight_decay=0.0
08/15 10:58:46 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stage3.1.blocks.0.conv1.bn.weight:weight_decay=0.0
08/15 10:58:46 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stage3.1.blocks.0.conv1.bn.bias:weight_decay=0.0
08/15 10:58:46 - mmengine - [4m[97mINFO[0m - paramwise_options -- backbone.stage3.1.blocks.0.conv2.depthwise_conv.bn.weight

08/15 10:58:49 - mmengine - [4m[97mINFO[0m - Checkpoints will be saved to /n/groups/datta/tim_sainburg/datasets/scratch/mm_training/rtmpose/rtmpose-m_8xb64-210e_ap10k-256x256_24-08-15-10-58-18.
08/15 10:59:11 - mmengine - [4m[97mINFO[0m - Epoch(train)   [1][ 50/222]  base_lr: 1.962342e-04 lr: 1.962342e-04  eta: 5:49:49  time: 0.450716  data_time: 0.150470  memory: 5699  loss: 0.549574  loss_kpt: 0.549574  acc_pose: 0.026947
08/15 10:59:29 - mmengine - [4m[97mINFO[0m - Epoch(train)   [1][100/222]  base_lr: 3.964324e-04 lr: 3.964324e-04  eta: 5:11:11  time: 0.352012  data_time: 0.091605  memory: 5699  loss: 0.518820  loss_kpt: 0.518820  acc_pose: 0.052520
08/15 10:59:48 - mmengine - [4m[97mINFO[0m - Epoch(train)   [1][150/222]  base_lr: 5.966306e-04 lr: 5.966306e-04  eta: 5:08:19  time: 0.391583  data_time: 0.129223  memory: 5699  loss: 0.479659  loss_kpt: 0.479659  acc_pose: 0.108368
08/15 11:00:06 - mmengine - [4m[97mINFO[0m - Epoch(train)   [1][200/222]  base_lr: 7.9682

08/15 11:09:44 - mmengine - [4m[97mINFO[0m - Epoch(train)   [8][200/222]  base_lr: 4.000000e-03 lr: 4.000000e-03  eta: 4:39:17  time: 0.364116  data_time: 0.086701  memory: 5699  loss: 0.529793  loss_kpt: 0.529793  acc_pose: 0.014900
08/15 11:09:51 - mmengine - [4m[97mINFO[0m - Exp name: rtmpose-m_8xb64-210e_ap10k-256x256_20240815_105837
08/15 11:10:11 - mmengine - [4m[97mINFO[0m - Epoch(train)   [9][ 50/222]  base_lr: 4.000000e-03 lr: 4.000000e-03  eta: 4:39:06  time: 0.404399  data_time: 0.130727  memory: 5699  loss: 0.531260  loss_kpt: 0.531260  acc_pose: 0.014127
08/15 11:10:30 - mmengine - [4m[97mINFO[0m - Epoch(train)   [9][100/222]  base_lr: 4.000000e-03 lr: 4.000000e-03  eta: 4:38:32  time: 0.361089  data_time: 0.089166  memory: 5699  loss: 0.532509  loss_kpt: 0.532509  acc_pose: 0.019789
08/15 11:10:48 - mmengine - [4m[97mINFO[0m - Epoch(train)   [9][150/222]  base_lr: 4.000000e-03 lr: 4.000000e-03  eta: 4:38:05  time: 0.366210  data_time: 0.092627  memory: 5699

08/15 11:19:38 - mmengine - [4m[97mINFO[0m - Exp name: rtmpose-m_8xb64-210e_ap10k-256x256_20240815_105837
08/15 11:19:58 - mmengine - [4m[97mINFO[0m - Epoch(train)  [16][ 50/222]  base_lr: 4.000000e-03 lr: 4.000000e-03  eta: 4:27:41  time: 0.385563  data_time: 0.129157  memory: 5699  loss: 0.527768  loss_kpt: 0.527768  acc_pose: 0.027489
08/15 11:20:16 - mmengine - [4m[97mINFO[0m - Epoch(train)  [16][100/222]  base_lr: 4.000000e-03 lr: 4.000000e-03  eta: 4:27:13  time: 0.356540  data_time: 0.092461  memory: 5699  loss: 0.533537  loss_kpt: 0.533537  acc_pose: 0.025213
08/15 11:20:33 - mmengine - [4m[97mINFO[0m - Epoch(train)  [16][150/222]  base_lr: 4.000000e-03 lr: 4.000000e-03  eta: 4:26:40  time: 0.348735  data_time: 0.097875  memory: 5699  loss: 0.530388  loss_kpt: 0.530388  acc_pose: 0.016880
08/15 11:20:51 - mmengine - [4m[97mINFO[0m - Epoch(train)  [16][200/222]  base_lr: 4.000000e-03 lr: 4.000000e-03  eta: 4:26:18  time: 0.364067  data_time: 0.090022  memory: 5699

### The config and path for running inference will be in the working directory