In [2]:
from backbones.mobilenetv3 import model_splitter
from icevision.all import *
from icevision.models.mmdet.utils import *
from mmdet.models import build_detector
from mmcv import Config
from pprint import pprint


In [3]:
model_name = "mobilenetv3_large_100_aa"
base_config_path = mmdet_configs_path / "retinanet"
config_path = base_config_path / "retinanet_r50_fpn_1x_coco.py"
cfg = Config.fromfile(config_path)

cfg.model.backbone = dict(
    type=f"TIMM_{model_name}",
    pretrained=True,
    out_indices=(0, 1, 2, 3, 4),
)
cfg.model.neck.in_channels = [16, 24, 40, 112, 960]

m = build_detector(cfg.model)

2021-05-12 00:03:05,834 - mmdet - INFO - load model from: torchvision://resnet50


## Set LR Scheduler & Optimizer

In [4]:
LRs = dict(
    stem=1e-6,
    blocks=[1e-5, 1e-4, 1e-4, 1e-3, 1e-3, 1e-3, 1e-3],
    neck=1e-2,
    bbox_head=1e-2,
    classifier_heads=None,
)
optimizer = torch.optim.SGD(
    model_splitter(
        m, 
        LR_stem = LRs["stem"],
        LR_blocks = LRs["blocks"],
        LR_neck = LRs["neck"],
        LR_bbox_head = LRs["bbox_head"],
        LR_classifier_heads = LRs["classifier_heads"]
    ), lr=0.01, momentum=0.9, weight_decay=0.0001
)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[16, 22])


In [5]:
optimizer.param_groups[2]['name']
optimizer.param_groups[2]['lr']

0.0001

In [6]:
NUM_EPOCHS = 24
BATCHES_PER_EPOCH = 2000
WARMUP_ITERS = 500

for epoch in tqdm(range(NUM_EPOCHS), desc="Epochs:"):
    #for batch in tqdm(train_dl, desc="Batch #"):
    for batch_idx, batch in enumerate(range(BATCHES_PER_EPOCH)):
        # print(scheduler.get_last_lr())
        if (epoch == 0) and (batch_idx < WARMUP_ITERS):
            lr_scale = min(1.0, float(batch_idx + 1) / WARMUP_ITERS)
            for pg in optimizer.param_groups:
                if "block_idx" in pg.keys():
                    pg["lr"] = lr_scale * LRs["blocks"][pg["block_idx"]]
                else:
                    pg["lr"] = lr_scale * LRs[pg["name"]]
                if pg["name"] == "bbox_head":
                    print(f'{pg["name"]} @batch#{batch_idx}: {pg["lr"]}')
    scheduler.step()
    break

Epochs::   0%|          | 0/24 [00:00<?, ?it/s]bbox_head @batch#0: 2e-05
bbox_head @batch#1: 4e-05
bbox_head @batch#2: 6e-05
bbox_head @batch#3: 8e-05
bbox_head @batch#4: 0.0001
bbox_head @batch#5: 0.00012
bbox_head @batch#6: 0.00014000000000000001
bbox_head @batch#7: 0.00016
bbox_head @batch#8: 0.00017999999999999998
bbox_head @batch#9: 0.0002
bbox_head @batch#10: 0.00021999999999999998
bbox_head @batch#11: 0.00024
bbox_head @batch#12: 0.00026
bbox_head @batch#13: 0.00028000000000000003
bbox_head @batch#14: 0.0003
bbox_head @batch#15: 0.00032
bbox_head @batch#16: 0.00034
bbox_head @batch#17: 0.00035999999999999997
bbox_head @batch#18: 0.00038
bbox_head @batch#19: 0.0004
bbox_head @batch#20: 0.00042
bbox_head @batch#21: 0.00043999999999999996
bbox_head @batch#22: 0.00046
bbox_head @batch#23: 0.00048
bbox_head @batch#24: 0.0005
bbox_head @batch#25: 0.00052
bbox_head @batch#26: 0.00054
bbox_head @batch#27: 0.0005600000000000001
bbox_head @batch#28: 0.00058
bbox_head @batch#29: 0.0006
bbo

In [43]:
LRs["blocks"][0]

1e-05

In [37]:
scheduler.get_last_lr()

[1e-06, 1e-05, 0.0001, 0.0001, 0.001, 0.001, 0.001, 0.001, 0.01, 0.01]