In [5]:
import logging
import os
import torch
from tensorboardX import SummaryWriter
from torch import nn, optim
from torch.utils.data import DataLoader

from jmodt.config import cfg, print_config_to_log, cfg_from_list
from jmodt.detection.datasets.kitti_dataset import KittiDataset
from jmodt.detection.modeling import train_functions
from jmodt.detection.modeling.point_rcnn import PointRCNN
from jmodt.utils import train_utils

In [6]:
# Manually define arguments instead of argparse
class Args:
    data_root = "data/"
    challenge = "tracking"
    finetune = True
    batch_size = 4
    output_dir = "output2"
    ckpt = None
    mgpus = False
    train_with_eval = False
    set_cfgs = None


args = Args()  # Instantiate the argument class

In [7]:
def create_logger(log_file):
    log_format = "%(asctime)s  %(levelname)5s  %(message)s"
    logging.basicConfig(level=logging.DEBUG, format=log_format, filename=log_file)
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(logging.Formatter(log_format))
    logging.getLogger(__name__).addHandler(console)
    return logging.getLogger(__name__)

In [8]:
def create_dataloader(logger, split):
    data_set = KittiDataset(
        root_dir=args.data_root,
        npoints=cfg.RPN.NUM_POINTS,
        split=split,
        mode="TRAIN",
        logger=logger,
        classes=cfg.CLASSES,
        challenge=args.challenge,
    )
    data_loader = DataLoader(
        data_set,
        batch_size=args.batch_size,
        pin_memory=True,
        shuffle=True,
        num_workers=4,
        collate_fn=data_set.collate_batch,
        drop_last=True,
    )
    return data_set, data_loader

In [9]:
def main():
    if args.set_cfgs is not None:
        cfg_from_list(args.set_cfgs)
    if args.finetune:
        cfg.RPN.FIXED = True
        cfg.TRAIN.FINETUNE = True
    else:
        cfg.RPN.FIXED = False
        cfg.TRAIN.FINETUNE = False

    root_result_dir = args.output_dir
    os.makedirs(root_result_dir, exist_ok=True)

    log_file = os.path.join(root_result_dir, "log_train.txt")
    logger = create_logger(log_file)
    logger.info("**********************Start logging**********************")

    # Log CUDA devices
    gpu_list = os.environ["CUDA_VISIBLE_DEVICES"] if "CUDA_VISIBLE_DEVICES" in os.environ else "ALL"
    logger.info(f"CUDA_VISIBLE_DEVICES={gpu_list}")

    # Log arguments
    for key, val in vars(args).items():
        logger.info(f"{key:16} {val}")

    print_config_to_log(cfg, logger=logger)

    # Tensorboard log
    tb_log = SummaryWriter(logdir=os.path.join(root_result_dir, "tensorboard"))

    # Create dataloader, network, and optimizer
    train_set, train_loader = create_dataloader(logger, split=cfg.TRAIN.SPLIT)
    val_set, val_loader = (
        create_dataloader(logger, split=cfg.TRAIN.VAL_SPLIT)
        if args.train_with_eval
        else (None, None)
    )

    fn_decorator = train_functions.model_joint_fn_decorator()

    model = PointRCNN(num_classes=train_set.num_class, use_xyz=True, mode="TRAIN")
    if args.mgpus:
        model = nn.DataParallel(model)
    model.cuda()
    params_to_update = model.parameters()

    start_epoch = it = 0
    last_epoch = -1
    if args.ckpt is not None:
        pure_model = model.module if isinstance(model, torch.nn.DataParallel) else model
        if cfg.TRAIN.FINETUNE:
            for param in pure_model.parameters():
                param.requires_grad = False
            params_to_update = \
                list(pure_model.rcnn_net.link_layer.parameters()) + \
                list(pure_model.rcnn_net.se_layer.parameters())
            for param in params_to_update:
                param.requires_grad = True
            optimizer = optim.AdamW([
                {'params': pure_model.rcnn_net.link_layer.parameters()},
                {'params': pure_model.rcnn_net.se_layer.parameters()},
            ], lr=cfg.TRAIN.LR, weight_decay=cfg.TRAIN.WEIGHT_DECAY)
        else:
            optimizer = optim.AdamW(params_to_update, lr=cfg.TRAIN.LR, weight_decay=cfg.TRAIN.WEIGHT_DECAY)
        if cfg.TRAIN.RELOAD_OPTIMIZER:
            it, start_epoch = train_utils.load_checkpoint(pure_model, optimizer, filename=args.ckpt, logger=logger)
            last_epoch = start_epoch + 1
        else:
            train_utils.load_checkpoint(pure_model, optimizer=None, filename=args.ckpt, logger=logger)
    else:
        optimizer = optim.AdamW(params_to_update, lr=cfg.TRAIN.LR, weight_decay=cfg.TRAIN.WEIGHT_DECAY)

    lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.TRAIN.TMAX,
                                                        eta_min=cfg.TRAIN.ETA_MIN, last_epoch=last_epoch)

    # start training
    logger.info('**********************Start training**********************')
    ckpt_dir = os.path.join(root_result_dir, 'ckpt')
    os.makedirs(ckpt_dir, exist_ok=True)
    trainer = train_utils.Trainer(
        model=model,
        params_to_update=params_to_update,
        model_fn_train=fn_decorator,
        optimizer=optimizer,
        ckpt_dir=ckpt_dir,
        lr_scheduler=lr_scheduler,
        model_fn_val=fn_decorator,
        tb_log=tb_log,
        eval_frequency=1,
        grad_norm_clip=cfg.TRAIN.GRAD_NORM_CLIP
    )

    trainer.train(
        it,
        start_epoch,
        cfg.TRAIN.EPOCHS,
        train_loader,
        val_loader
    )

    logger.info('**********************End training**********************')

In [None]:
# Call the main function directly
if __name__ == "__main__":
    main()

2024-11-19 17:40:26,866   INFO  **********************Start logging**********************
2024-11-19 17:40:26,868   INFO  CUDA_VISIBLE_DEVICES=ALL
2024-11-19 17:40:26,870   INFO  cfg.TAG: default
2024-11-19 17:40:26,871   INFO  cfg.CLASSES: Car
2024-11-19 17:40:26,871   INFO  cfg.INCLUDE_SIMILAR_TYPE: True
2024-11-19 17:40:26,873   INFO  cfg.AUG_DATA: False
2024-11-19 17:40:26,874   INFO  cfg.AUG_METHOD_LIST: ['rotation', 'scaling', 'flip']
2024-11-19 17:40:26,875   INFO  cfg.AUG_METHOD_PROB: [1.0, 1.0, 0.5]
2024-11-19 17:40:26,876   INFO  cfg.AUG_ROT_RANGE: 18
2024-11-19 17:40:26,877   INFO  cfg.GT_AUG_ENABLED: False
2024-11-19 17:40:26,879   INFO  cfg.GT_EXTRA_NUM: 15
2024-11-19 17:40:26,880   INFO  cfg.GT_AUG_RAND_NUM: True
2024-11-19 17:40:26,881   INFO  cfg.GT_AUG_APPLY_PROB: 1.0
2024-11-19 17:40:26,882   INFO  cfg.GT_AUG_HARD_RATIO: 0.6
2024-11-19 17:40:26,883   INFO  cfg.PC_REDUCE_BY_RANGE: True
2024-11-19 17:40:26,886   INFO  cfg.PC_AREA_SCOPE: [[-40.   40. ]
 [ -1.    3. ]
 [ 