# MoveNet_FPN 训练笔记

In [2]:
# %% [markdown]
# # MoveNet_FPN 训练笔记

# %% [markdown]
# ## 导入工程

# %%
# Standard library
import os

# Third-party
import cv2
import timm
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm

# Project (sparrow)
from sparrow.models.movenet_fpn import MoveNet_FPN, decode_movenet_outputs
from sparrow.datasets.coco_kpts import create_kpts_dataloader
from sparrow.losses.movenet_loss import MoveNetLoss, evaluate_local
from sparrow.utils.ema import EMA
from sparrow.utils.visual_movenet import visualize_movenet

# %% [markdown]
# ## 参数设置
#
# ### 系统参数

# %%
# Train on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Data-loading settings.
BATCH_SIZE = 8          # training batch size (validation uses twice this)
NUM_WORKERS = 4         # forwarded to the dataloader factory as num_workers

# Model / target geometry.
INPUT_SIZE = 192        # passed as img_size to the loaders and input_size to the visualiser
NUM_JOINTS = 17         # number of keypoints the model predicts
TARGET_STRIDE = 4       # passed as out_stride / target_stride / stride downstream
UPSAMPLE = True         # forwarded to MoveNet_FPN as upsample_to_quarter

# Filesystem locations.
COCO_ROOT = "./data/coco2017_movenet"            # COCO training dataset root
WEIGHTS_DIR = "./outputs/movenet"                # directory where checkpoints are written
TEST_IMAGE_PATH = "./res/girl_with_bags.png"     # image used for periodic visualisation

# %% [markdown]
# ### 学习参数

# %%
# Optimiser hyper-parameters.
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 1e-4

# Schedule length and stabilisation knobs.
EPOCHS = 100
WARMUP_EPOCHS = 2               # number of epochs of linear learning-rate warm-up
GRADIENT_CLIP_VAL = 5.0         # max gradient norm used for clipping

# Progress state; both values are overwritten when a checkpoint is resumed.
START_EPOCH = 0
BEST_VAL_LOSS = float('inf')

# %% [markdown]
# ## 创建模型

# %%
# Backbone: a pretrained MobileNetV3-Large from timm, built in features-only
# mode so it exposes the intermediate feature maps selected by out_indices.
backbone_fpn = timm.create_model(
    'mobilenetv3_large_100',
    pretrained=True,
    features_only=True,
    out_indices=(2, 3, 4),
)

# Pose network built on top of the backbone, then moved to the chosen device.
model_fpn = MoveNet_FPN(
    backbone_fpn,
    num_joints=NUM_JOINTS,
    fpn_out_channels=128,
    upsample_to_quarter=UPSAMPLE,
    out_stride=TARGET_STRIDE,
)
model_fpn = model_fpn.to(device)

# EMA wrapper around the model; its ema_model is the one used for
# validation, checkpointing and visualisation later in the notebook.
ema = EMA(model_fpn)

# %% [markdown]
# ## 加载数据

# %%
# Arguments that are identical for the training and validation loaders.
_shared_loader_args = dict(
    dataset_root=COCO_ROOT,
    img_size=INPUT_SIZE,
    num_workers=NUM_WORKERS,
    target_stride=TARGET_STRIDE,
    pin_memory=True,
)

# Training loader (light augmentation can be switched on here if desired;
# both flags are currently off).
train_aug_config = dict(use_flip=False, use_color_aug=False)
train_loader = create_kpts_dataloader(
    batch_size=BATCH_SIZE,
    aug_cfg=train_aug_config,
    is_train=True,
    **_shared_loader_args,
)

# Validation loader: random augmentation stays strictly off so the
# validation curve is stable. It uses twice the training batch size.
test_aug_config = dict(use_flip=False, use_color_aug=False)
val_loader = create_kpts_dataloader(
    batch_size=2 * BATCH_SIZE,
    aug_cfg=test_aug_config,
    is_train=False,
    **_shared_loader_args,
)

# %% [markdown]
# ## 损失/优化/调度

# %%
# Loss: bone-consistency term enabled, background term disabled for now.
criterion = MoveNetLoss(
    hm_weight=1.0,
    ct_weight=1.0,
    reg_weight=1.5,
    off_weight=1.0,
    bone_weight=0.15,   # suggested tuning range: 0.10 - 0.20
    bg_weight=0.0,
)

# Optimizer: AdamW is the recommended, more stable choice
# (swap in torch.optim.Adam below if plain Adam is wanted).
optimizer = torch.optim.AdamW(
    model_fpn.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

# Learning-rate schedule: cosine annealing over EPOCHS steps down to 1e-6.
scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=EPOCHS, eta_min=1e-6)

# %% [markdown]
# ## 加载预训练/断点

# %%
# Make sure the output directory exists before any checkpoint is read or written.
os.makedirs(WEIGHTS_DIR, exist_ok=True)

# last.pt is rewritten at the end of every epoch; if it is present, pick up
# training from where it stopped.
last_pt_path = os.path.join(WEIGHTS_DIR, "last.pt")
if os.path.exists(last_pt_path):
    print("--- Resuming training from last.pt ---")
    checkpoint = torch.load(last_pt_path, map_location=device)

    # Restore every stateful object from its entry in the checkpoint,
    # in the order: model, EMA copy, optimizer, scheduler.
    for stateful, state_key in (
        (model_fpn, 'model'),
        (ema.ema_model, 'ema_model'),
        (optimizer, 'optimizer'),
        (scheduler, 'scheduler'),
    ):
        stateful.load_state_dict(checkpoint[state_key])

    # The checkpoint stores the last *completed* epoch, so continue with the next one.
    BEST_VAL_LOSS = checkpoint['best_val_loss']
    START_EPOCH = checkpoint['epoch'] + 1

    resumed_lr = optimizer.param_groups[0]['lr']
    print(f"Resumed from epoch {START_EPOCH-1}. Best validation loss so far: {BEST_VAL_LOSS:.4f}")
    print(f"Current learning rate is {resumed_lr:.6f}")

# %% [markdown]
# ## 训练循环

# %%
print("\n--- Starting Training ---")

# The EMA model is visualised on TEST_IMAGE_PATH every VIZ_EVERY epochs.
VIZ_EVERY = 2

# Linear warm-up is counted in optimizer steps. Initialising the counter from
# START_EPOCH keeps the warm-up position consistent after resuming.
warmup_steps = WARMUP_EPOCHS * len(train_loader)
current_step = START_EPOCH * len(train_loader)

for epoch in range(START_EPOCH, EPOCHS):
    model_fpn.train()

    # Per-epoch running SUMS of each loss term (not averages); these are the
    # numbers shown in the progress-bar postfix.
    epoch_loss_heatmap = 0.0
    epoch_loss_center  = 0.0
    epoch_loss_regs    = 0.0
    epoch_loss_offsets = 0.0
    epoch_loss_bone    = 0.0
    epoch_loss_bg      = 0.0

    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    pbar = tqdm(train_loader, desc=f"  🟢 [Training] lr: {optimizer.param_groups[0]['lr']:.6f} ")

    for imgs, labels, kps_masks, _ in pbar:
        # Linear warm-up: scale the base learning rate from ~0 up to
        # LEARNING_RATE over the first warmup_steps optimizer steps.
        if current_step < warmup_steps:
            lr_scale = (current_step + 1) / max(1, warmup_steps)
            for g in optimizer.param_groups:
                g['lr'] = LEARNING_RATE * lr_scale

        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        kps_masks = kps_masks.to(device, non_blocking=True)

        # Broadcast a [B, J] or [B, J, 1, 1] keypoint mask to [B, J, Hf, Wf],
        # where Hf/Wf are the spatial dims of the label maps.
        Hf, Wf = labels.shape[-2], labels.shape[-1]
        if kps_masks.dim() == 2 and kps_masks.shape[1] == NUM_JOINTS:
            kps_masks = kps_masks[:, :, None, None].float().expand(-1, -1, Hf, Wf).contiguous()
        elif kps_masks.dim() == 4 and kps_masks.shape[2] == 1 and kps_masks.shape[3] == 1:
            kps_masks = kps_masks.float().expand(-1, -1, Hf, Wf).contiguous()
        # Otherwise the mask is assumed to already be [B, J, Hf, Wf] and is used as-is.

        # Forward pass.
        preds = model_fpn(imgs)
        total_loss, loss_dict = criterion(preds, labels, kps_masks)

        # Backward pass.
        optimizer.zero_grad(set_to_none=True)
        total_loss.backward()

        # Gradient clipping by global norm.
        torch.nn.utils.clip_grad_norm_(model_fpn.parameters(), max_norm=GRADIENT_CLIP_VAL)

        # Parameter update, then refresh the EMA copy of the weights.
        optimizer.step()
        ema.update(model_fpn)
        current_step += 1

        # Accumulate loss terms for display. The bone / background terms are
        # only read when their weight is non-zero.
        epoch_loss_heatmap += float(loss_dict["loss_heatmap"])
        epoch_loss_center  += float(loss_dict["loss_center"])
        epoch_loss_regs    += float(loss_dict["loss_regs"])
        epoch_loss_offsets += float(loss_dict["loss_offsets"])
        if criterion.bone_weight > 0:
            epoch_loss_bone  += float(loss_dict["loss_bone"])
        if criterion.bg_weight > 0:
            epoch_loss_bg    += float(loss_dict["loss_bg"])

        pbar.set_postfix(
            hm=f"{epoch_loss_heatmap:.2f}",
            center=f"{epoch_loss_center:.2f}",
            regs=f"{epoch_loss_regs:.2f}",
            offsets=f"{epoch_loss_offsets:.2f}",
            bone=(f"{epoch_loss_bone:.2f}" if criterion.bone_weight > 0 else "0.00"),
            bg=(f"{epoch_loss_bg:.2f}" if criterion.bg_weight > 0 else "0.00"),
        )

    # Step the cosine scheduler once per epoch, starting with the last
    # warm-up epoch (earlier epochs are driven by the manual warm-up above).
    if epoch >= WARMUP_EPOCHS - 1:
        scheduler.step()

    # ===== Validation (on the EMA model) =====
    avg_total_loss, avg_dict = evaluate_local(
        ema.ema_model, val_loader, criterion, device,
        decoder=decode_movenet_outputs, stride=TARGET_STRIDE
    )
    print(f"  📜 Epoch {epoch+1}/{EPOCHS} average loss: {avg_total_loss:.4f}")

    # ===== Save checkpoints =====
    # Decide whether this epoch is the new best BEFORE building the checkpoint,
    # so that last.pt also records the up-to-date best loss. Otherwise a resume
    # from last.pt would restore a stale (too high) BEST_VAL_LOSS and a worse
    # later epoch could overwrite best.pt.
    is_best = avg_total_loss < BEST_VAL_LOSS
    if is_best:
        BEST_VAL_LOSS = avg_total_loss

    checkpoint = {
        'epoch': epoch,
        'model': model_fpn.state_dict(),
        'ema_model': ema.ema_model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'best_val_loss': BEST_VAL_LOSS,
    }
    torch.save(checkpoint, last_pt_path)
    print(f"  🎯 Saved last checkpoint to {last_pt_path}")

    if is_best:
        best_pt_path = os.path.join(WEIGHTS_DIR, "best.pt")
        torch.save(checkpoint, best_pt_path)
        print(f"  🎉 New best model found! Saved to {best_pt_path}")

    # ===== Visualise every VIZ_EVERY epochs (EMA model; keypoints + skeleton) =====
    if (epoch + 1) % VIZ_EVERY == 0:
        viz_dir = os.path.join(WEIGHTS_DIR, "viz")
        os.makedirs(viz_dir, exist_ok=True)

        # cv2.imread returns None instead of raising when the file is missing
        # or unreadable, so check explicitly.
        img_bgr = cv2.imread(TEST_IMAGE_PATH)
        if img_bgr is None:
            raise FileNotFoundError(f"TEST_IMAGE_PATH not found: {TEST_IMAGE_PATH}")

        img_resized = cv2.resize(img_bgr, (600, 600), interpolation=cv2.INTER_LINEAR)
        save_path = os.path.join(viz_dir, f"epoch_{epoch+1:03d}.png")

        visualize_movenet(
            model=ema.ema_model,        # use the EMA weights
            image=img_resized,
            device=device,
            decoder=decode_movenet_outputs,
            input_size=INPUT_SIZE,
            stride=TARGET_STRIDE,
            topk_centers=3,
            center_thresh=0.10,
            keypoint_thresh=0.03,
            draw_on_orig=True,
            draw_heatmaps=True,
            save_path=save_path,
            show=False,
            # If visualize_movenet still supports the bbox switches,
            # draw keypoints + skeleton only:
            # draw_bbox=False,
            # draw_skeleton=True,
            # If single-person filtering (force_single) is integrated, optionally add:
            # force_single=True
        )
        # Report only after the image has actually been rendered and saved.
        print("  📊 Visualized predictions on test image")

print("--- Training Finished ---")


Unexpected keys (classifier.bias, classifier.weight, conv_head.bias, conv_head.weight) found while loading pretrained weights. This may be expected if model is being adapted.


--- Resuming training from last.pt ---
Resumed from epoch 4. Best validation loss so far: 0.0041
Current learning rate is 0.000299

--- Starting Training ---

Epoch 6/100


  🟢 [Training] lr: 0.000299 : 100%|██████████| 16220/16220 [06:04<00:00, 44.45it/s, bg=0.00, bone=9.76, center=8.44, hm=3.47, offsets=2.71, regs=36.83]
  🟡 [Validating] : 100%|██████████| 340/340 [00:05<00:00, 66.50it/s, bg=0.000000, bone=0.000590, ct=0.000525, hm=0.000210, off=0.000140, pck=1.50%, reg=0.002191, tot=0.004250]


  📜 Epoch 6/100 average loss: 0.0042
  🎯 Saved last checkpoint to ./outputs/movenet/last.pt
  📊 Visualized predictions on test image

Epoch 7/100


  🟢 [Training] lr: 0.000298 :  67%|██████▋   | 10795/16220 [04:04<02:02, 44.14it/s, bg=0.00, bone=6.48, center=5.60, hm=2.30, offsets=1.75, regs=24.52]


KeyboardInterrupt: 