# MoveNet_FPN 训练笔记

## 导入工程

In [1]:
# 导入系统库
import os
import timm
from tqdm import tqdm

# 导入sparrow
from sparrow.models.movenet_fpn import MoveNet_FPN, decode_movenet_outputs
from sparrow.datasets.coco_kpts import create_kpts_dataloader
from sparrow.losses.movenet_loss import MoveNetLoss, evaluate
from sparrow.utils.ema import EMA
from sparrow.utils.visual_movenet import visualize_movenet

# 导入torch库
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

  from .autonotebook import tqdm as notebook_tqdm


## 参数设置

### 系统参数

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
INPUT_SIZE = 192
BATCH_SIZE = 8
NUM_WORKERS = 4
NUM_JOINTS = 17
UPSAMPLE = True
TARGET_STRIDE=4

COCO_ROOT = "./data/coco2017_movenet"       # COCO训练数据集
WEIGHTS_DIR = "./outputs/movenet"           # 保存权重的目录
TEST_IMAGE_PATH = "./res/girl_with_bags.png"    # 测试图片路径

### 学习参数

In [None]:
START_EPOCH = 0
EPOCHS=100                      # 训练次数
BEST_VAL_LOSS = float('inf')
# LEARNING_RATE = 1e-4            # 初始学习率
# WEIGHT_DECAY = 1e-3
WARMUP_EPOCHS = 2               # 预热
GRADIENT_CLIP_VAL = 5.0         # 梯度裁剪的阈值

LEARNING_RATE = 3e-4         # from 1e-4 -> 3e-4
WEIGHT_DECAY  = 1e-4         # from 1e-3 -> 1e-4 (或 0)

## 创建模型

In [None]:
backbone_fpn = timm.create_model('mobilenetv3_large_100', pretrained=True, features_only=True, out_indices=(2, 3, 4))
model_fpn = MoveNet_FPN(backbone_fpn, num_joints=NUM_JOINTS, fpn_out_channels=128, upsample_to_quarter=UPSAMPLE)
model_fpn.to(device)

# EMA评估器
ema = EMA(model_fpn)

Unexpected keys (classifier.bias, classifier.weight, conv_head.bias, conv_head.weight) found while loading pretrained weights. This may be expected if model is being adapted.


## 加载数据

In [None]:
# 创建训练数据加载器 (来自 dataloader.py)
train_aug_config = { "use_flip": True, "use_color_aug": True }
train_loader = create_kpts_dataloader(
    dataset_root=COCO_ROOT,
    img_size=INPUT_SIZE,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    target_stride=TARGET_STRIDE,
    pin_memory=True,
    aug_cfg=train_aug_config,
    is_train=True
)

# 创建验证集数据加载器
test_aug_config = {"use_flip": True}
val_loader = create_kpts_dataloader(
    dataset_root=COCO_ROOT,
    img_size=INPUT_SIZE,
    batch_size=BATCH_SIZE * 2,  # 验证时通常可以用更大的 batch size
    num_workers=NUM_WORKERS,
    target_stride=TARGET_STRIDE,
    pin_memory=True,
    aug_cfg=test_aug_config,
    is_train=False
)

  transforms.append(A.PadIfNeeded(
  original_init(self, **validated_kwargs)
  transforms.append(A.ShiftScaleRotate(


## 损失优化调度

In [None]:
# 损失函数
# criterion = MoveNetLoss(reg_weight=2.0, off_weight=1.0)
criterion = MoveNetLoss(hm_weight=1.0, ct_weight=1.0, reg_weight=1.5, off_weight=1.0)

# 优化器
# optimizer = torch.optim.AdamW(model_fpn.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

optimizer = torch.optim.Adam(  # Adam 更宽松；若继续用 AdamW 也行
    model_fpn.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)

# 学习调度器
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-6) 

## 加载预训练权重

In [7]:
# 确保存放预训练的目录存在
os.makedirs(WEIGHTS_DIR, exist_ok=True) # 确保目录存在

# 断点续训逻辑
last_pt_path = os.path.join(WEIGHTS_DIR, "last.pt")
if os.path.exists(last_pt_path):
    print("--- Resuming training from last.pt ---")

    # 加载pt文件
    checkpoint = torch.load(last_pt_path, map_location=device)
    
    # 从pt中读取模型权重
    model_fpn.load_state_dict(checkpoint['model'])
    
    # 加载EMA状态
    ema.ema_model.load_state_dict(checkpoint['ema_model'])

    # 加载优化器状态
    optimizer.load_state_dict(checkpoint['optimizer'])

    # 加载调度器状态
    scheduler.load_state_dict(checkpoint['scheduler'])

    # 更新EPOCH状态
    START_EPOCH = checkpoint['epoch'] + 1
    
    # 更新最佳损失状态
    BEST_VAL_LOSS = checkpoint['best_val_loss']
    
    # 打印确认消息
    print(f"Resumed from epoch {START_EPOCH-1}. Best validation loss so far: {BEST_VAL_LOSS:.4f}")
    print(f"Current learning rate is {optimizer.param_groups[0]['lr']:.6f}")

## 训练循环

In [None]:
# --- 训练循环 ---
print("\n--- Starting Training ---")

# 计算预热的总步数
warmup_steps = WARMUP_EPOCHS * len(train_loader)
current_step = START_EPOCH * len(train_loader)

for epoch in range(START_EPOCH, EPOCHS):
    model_fpn.train() # 设置为训练模式

    epoch_loss_heatmap = 0.0
    epoch_loss_center = 0.0
    epoch_loss_regs = 0.0
    epoch_loss_offsets = 0.0

    # 进度条信息
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    pbar = tqdm(train_loader, desc=f"  🟢 [Training] lr: {optimizer.param_groups[0]['lr']:.6f} ")

    for i, (imgs, labels, kps_masks, _) in enumerate(pbar):
    
        # 学习率预热逻辑
        if current_step < warmup_steps:
            # 线性预热
            lr_scale = (current_step + 1) / warmup_steps
            for param_group in optimizer.param_groups:
                param_group['lr'] = LEARNING_RATE * lr_scale
        
        # 正常训练步骤
        imgs = imgs.to(device)
        labels = labels.to(device)
        kps_masks = kps_masks.to(device)

        # 前向传播
        preds = model_fpn(imgs)

        # 计算损失
        total_loss, loss_dict = criterion(preds, labels, kps_masks) 

        # 反向传播和优化
        optimizer.zero_grad()
        total_loss.backward()

        # 梯度裁剪
        torch.nn.utils.clip_grad_norm_(model_fpn.parameters(), max_norm=GRADIENT_CLIP_VAL)
        
        # 更新模型参数
        optimizer.step()

        # 更新EMA
        ema.update(model_fpn)
        current_step += 1

        epoch_loss_center += loss_dict["loss_center"]
        epoch_loss_heatmap += loss_dict["loss_heatmap"]
        epoch_loss_offsets += loss_dict["loss_offsets"]
        epoch_loss_regs += loss_dict["loss_regs"]

        # 显示当前信息
        pbar.set_postfix(hm=f"{epoch_loss_heatmap:.2f}", 
                         center=f"{epoch_loss_center:.2f}", 
                         offsets=f"{epoch_loss_offsets:.2f}",
                         regs=f"{epoch_loss_regs:.2f}")
    # end-for: 训练结束

    # 每个 epoch 结束后，更新学习率调度器
    if epoch >= WARMUP_EPOCHS - 1: # -1 是因为 step() 应在 optimizer.step() 之后调用
        scheduler.step()

    # 每个 epoch 结束后，进行验证
    avg_total_loss, _, = evaluate(ema.ema_model, val_loader, criterion, device, decoder=decode_movenet_outputs)

    # 生成本次epoch报告
    print(f"  📜 Epoch {epoch+1}/{EPOCHS} average loss: {avg_total_loss:.4f}")

    # 保存 last.pt 和 best.pt
    checkpoint = {
        'epoch': epoch,
        'model': model_fpn.state_dict(),
        'ema_model': ema.ema_model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'best_val_loss': BEST_VAL_LOSS,
    }

    # 保存 last.pt
    torch.save(checkpoint, last_pt_path)
    print(f"  🎯 Saved last checkpoint to {last_pt_path}")
    
    # 如果当前是最佳模型，则保存 best.pt
    if avg_total_loss < BEST_VAL_LOSS:
        BEST_VAL_LOSS = avg_total_loss
        checkpoint['best_val_loss'] = BEST_VAL_LOSS # 更新 checkpoint 中的最佳损失
        best_pt_path = os.path.join(WEIGHTS_DIR, "best.pt")
        torch.save(checkpoint, best_pt_path)
        print(f"  🎉 New best model found! Saved to {best_pt_path}")
        
    # --- 每 5 个 epoch，可视化一次预测结果 ---
    if (epoch + 1) % 5 == 0:
        print(f"  📊 Visualized predictions on test image")
        viz_dir = os.path.join(WEIGHTS_DIR, "viz")
        os.makedirs(viz_dir, exist_ok=True)

        # 1) 加载图片
        import cv2
        img_bgr = cv2.imread(TEST_IMAGE_PATH)
        if img_bgr is None:
            raise FileNotFoundError(f"TEST_IMAGE_PATH not found: {TEST_IMAGE_PATH}")
        
        # 非等比例直接拉伸
        img_resized = cv2.resize(img_bgr, (600, 600), interpolation=cv2.INTER_LINEAR)  

        # 2) 可视化保存叠框结果（会回到这张 800x600 的坐标系上绘制）
        save_path = os.path.join(viz_dir, f"epoch_{epoch+1:03d}.png")
        visualize_movenet(
            model=model_fpn,
            image=img_resized,               # 或者传入 np.ndarray(BGR/RGB 都行，这里内部按 RGB 处理显示)
            device=device,
            decoder=decode_movenet_outputs, # 直接用你已有的解码器
            input_size=192,
            stride=8,                       # 若模型把 P3 上采样到1/4，这里改成 4
            topk_centers=1,                 # 先只取最强中心
            center_thresh=0.25,
            keypoint_thresh=0.05,
            draw_bbox=True,
            draw_skeleton=True,
            draw_on_orig=True,              # 画在原图上
            draw_heatmaps=True,
            save_path=save_path,    # 或 None 并 show=True 直接显示
            show=False
        )
            
print("--- Training Finished ---")    


--- Starting Training ---

Epoch 1/100


  🟢 [Training] lr: 0.000100 : 100%|██████████| 16220/16220 [04:55<00:00, 54.82it/s, center=43.66, hm=95.31, offsets=6882.29, regs=32410.39]
  🟡 [Validating] : 100%|██████████| 340/340 [00:05<00:00, 61.68it/s, ct=0.706919, hm=0.248512, off=84.057418, pck=11.28%, reg=567.797134, tot=1220.607111]


  📜 Epoch 1/100 average loss: 3.5900
  🎯 Saved last checkpoint to ./outputs/movenet/last.pt
  🎉 New best model found! Saved to ./outputs/movenet/best.pt

Epoch 2/100


  🟢 [Training] lr: 0.000050 : 100%|██████████| 16220/16220 [04:48<00:00, 56.31it/s, center=32.65, hm=11.61, offsets=3946.55, regs=27308.93]
  🟡 [Validating] : 100%|██████████| 340/340 [00:06<00:00, 50.88it/s, ct=0.688918, hm=0.238077, off=81.188602, pck=10.76%, reg=542.902588, tot=1167.920774]


  📜 Epoch 2/100 average loss: 3.4351
  🎯 Saved last checkpoint to ./outputs/movenet/last.pt
  🎉 New best model found! Saved to ./outputs/movenet/best.pt

Epoch 3/100


  🟢 [Training] lr: 0.000100 : 100%|██████████| 16220/16220 [04:51<00:00, 55.60it/s, center=31.74, hm=11.33, offsets=3875.22, regs=26219.12]
  🟡 [Validating] : 100%|██████████| 340/340 [00:07<00:00, 46.65it/s, ct=0.671145, hm=0.233503, off=80.145960, pck=13.20%, reg=529.805649, tot=1140.661907]


  📜 Epoch 3/100 average loss: 3.3549
  🎯 Saved last checkpoint to ./outputs/movenet/last.pt
  🎉 New best model found! Saved to ./outputs/movenet/best.pt

Epoch 4/100


  🟢 [Training] lr: 0.000100 : 100%|██████████| 16220/16220 [04:49<00:00, 55.98it/s, center=31.36, hm=11.17, offsets=3844.81, regs=25593.73]
  🟡 [Validating] : 100%|██████████| 340/340 [00:07<00:00, 46.93it/s, ct=0.668420, hm=0.231180, off=79.715498, pck=14.45%, reg=529.605000, tot=1139.825097]


  📜 Epoch 4/100 average loss: 3.3524
  🎯 Saved last checkpoint to ./outputs/movenet/last.pt
  🎉 New best model found! Saved to ./outputs/movenet/best.pt

Epoch 5/100


  🟢 [Training] lr: 0.000100 : 100%|██████████| 16220/16220 [04:49<00:00, 56.08it/s, center=31.16, hm=11.08, offsets=3825.33, regs=25222.87]
  🟡 [Validating] : 100%|██████████| 340/340 [00:07<00:00, 44.78it/s, ct=0.663141, hm=0.230199, off=79.367148, pck=14.54%, reg=526.820862, tot=1133.902211]


  📜 Epoch 5/100 average loss: 3.3350
  🎯 Saved last checkpoint to ./outputs/movenet/last.pt
  🎉 New best model found! Saved to ./outputs/movenet/best.pt
  📊 Visualized predictions on test image

Epoch 6/100


  🟢 [Training] lr: 0.000100 : 100%|██████████| 16220/16220 [04:51<00:00, 55.72it/s, center=31.06, hm=11.02, offsets=3816.24, regs=24987.40]
  🟡 [Validating] : 100%|██████████| 340/340 [00:07<00:00, 44.66it/s, ct=0.662178, hm=0.228614, off=79.138452, pck=14.68%, reg=527.886556, tot=1135.802358]


  📜 Epoch 6/100 average loss: 3.3406
  🎯 Saved last checkpoint to ./outputs/movenet/last.pt

Epoch 7/100


  🟢 [Training] lr: 0.000099 : 100%|██████████| 16220/16220 [04:46<00:00, 56.68it/s, center=30.92, hm=10.98, offsets=3809.99, regs=24795.22]
  🟡 [Validating] : 100%|██████████| 340/340 [00:07<00:00, 45.06it/s, ct=0.654661, hm=0.228206, off=78.949096, pck=14.43%, reg=529.151944, tot=1138.135852]


  📜 Epoch 7/100 average loss: 3.3475
  🎯 Saved last checkpoint to ./outputs/movenet/last.pt

Epoch 8/100


  🟢 [Training] lr: 0.000099 : 100%|██████████| 16220/16220 [04:48<00:00, 56.19it/s, center=30.86, hm=10.93, offsets=3802.31, regs=24595.43]
  🟡 [Validating] : 100%|██████████| 340/340 [00:07<00:00, 43.61it/s, ct=0.655297, hm=0.227600, off=78.784068, pck=15.38%, reg=522.620069, tot=1124.907103]


  📜 Epoch 8/100 average loss: 3.3086
  🎯 Saved last checkpoint to ./outputs/movenet/last.pt
  🎉 New best model found! Saved to ./outputs/movenet/best.pt

Epoch 9/100


  🟢 [Training] lr: 0.000099 : 100%|██████████| 16220/16220 [04:50<00:00, 55.90it/s, center=30.81, hm=10.91, offsets=3796.86, regs=24476.32]
  🟡 [Validating] : 100%|██████████| 340/340 [00:07<00:00, 45.25it/s, ct=0.653749, hm=0.226525, off=78.636585, pck=15.08%, reg=519.870103, tot=1119.257067]


  📜 Epoch 9/100 average loss: 3.2919
  🎯 Saved last checkpoint to ./outputs/movenet/last.pt
  🎉 New best model found! Saved to ./outputs/movenet/best.pt

Epoch 10/100


  🟢 [Training] lr: 0.000098 : 100%|██████████| 16220/16220 [04:58<00:00, 54.37it/s, center=30.69, hm=10.88, offsets=3795.29, regs=24319.43]
  🟡 [Validating] : 100%|██████████| 340/340 [00:07<00:00, 42.58it/s, ct=0.654503, hm=0.226674, off=78.485995, pck=15.68%, reg=518.867667, tot=1117.102509]


  📜 Epoch 10/100 average loss: 3.2856
  🎯 Saved last checkpoint to ./outputs/movenet/last.pt
  🎉 New best model found! Saved to ./outputs/movenet/best.pt
  📊 Visualized predictions on test image

Epoch 11/100


  🟢 [Training] lr: 0.000098 : 100%|██████████| 16220/16220 [04:46<00:00, 56.53it/s, center=30.69, hm=10.86, offsets=3790.36, regs=24193.75]
  🟡 [Validating] : 100%|██████████| 340/340 [00:07<00:00, 44.70it/s, ct=0.650639, hm=0.226193, off=78.449866, pck=15.46%, reg=520.358190, tot=1120.043079]


  📜 Epoch 11/100 average loss: 3.2942
  🎯 Saved last checkpoint to ./outputs/movenet/last.pt

Epoch 12/100


  🟢 [Training] lr: 0.000098 : 100%|██████████| 16220/16220 [04:54<00:00, 55.03it/s, center=30.59, hm=10.85, offsets=3787.01, regs=24105.17]
  🟡 [Validating] : 100%|██████████| 340/340 [00:07<00:00, 45.46it/s, ct=0.651415, hm=0.226204, off=78.370030, pck=15.95%, reg=522.753983, tot=1124.755614]


  📜 Epoch 12/100 average loss: 3.3081
  🎯 Saved last checkpoint to ./outputs/movenet/last.pt

Epoch 13/100


  🟢 [Training] lr: 0.000097 : 100%|██████████| 16220/16220 [04:49<00:00, 56.10it/s, center=30.54, hm=10.83, offsets=3780.17, regs=23985.45]
  🟡 [Validating] : 100%|██████████| 340/340 [00:07<00:00, 44.35it/s, ct=0.649337, hm=0.225954, off=78.281934, pck=15.73%, reg=521.433879, tot=1122.024985]


  📜 Epoch 13/100 average loss: 3.3001
  🎯 Saved last checkpoint to ./outputs/movenet/last.pt

Epoch 14/100


  🟢 [Training] lr: 0.000097 : 100%|██████████| 16220/16220 [04:48<00:00, 56.15it/s, center=30.52, hm=10.83, offsets=3776.45, regs=23893.63]
  🟡 [Validating] : 100%|██████████| 340/340 [00:07<00:00, 44.73it/s, ct=0.649652, hm=0.225686, off=78.250483, pck=15.63%, reg=517.849880, tot=1114.825584]


  📜 Epoch 14/100 average loss: 3.2789
  🎯 Saved last checkpoint to ./outputs/movenet/last.pt
  🎉 New best model found! Saved to ./outputs/movenet/best.pt

Epoch 15/100


  🟢 [Training] lr: 0.000096 : 100%|██████████| 16220/16220 [04:48<00:00, 56.13it/s, center=30.53, hm=10.83, offsets=3775.63, regs=23792.64]
  🟡 [Validating] : 100%|██████████| 340/340 [00:07<00:00, 42.89it/s, ct=0.648805, hm=0.225543, off=78.210124, pck=16.40%, reg=521.367825, tot=1121.820127]


  📜 Epoch 15/100 average loss: 3.2995
  🎯 Saved last checkpoint to ./outputs/movenet/last.pt
  📊 Visualized predictions on test image

Epoch 16/100


  🟢 [Training] lr: 0.000095 : 100%|██████████| 16220/16220 [04:47<00:00, 56.46it/s, center=30.51, hm=10.81, offsets=3771.41, regs=23693.75]
  🟡 [Validating] : 100%|██████████| 340/340 [00:08<00:00, 42.49it/s, ct=0.650102, hm=0.225398, off=78.107334, pck=15.15%, reg=522.000938, tot=1122.984713]


  📜 Epoch 16/100 average loss: 3.3029
  🎯 Saved last checkpoint to ./outputs/movenet/last.pt

Epoch 17/100


  🟢 [Training] lr: 0.000095 : 100%|██████████| 16220/16220 [04:46<00:00, 56.54it/s, center=30.47, hm=10.80, offsets=3769.40, regs=23644.88]
  🟡 [Validating] : 100%|██████████| 340/340 [00:07<00:00, 43.96it/s, ct=0.645728, hm=0.225001, off=78.099929, pck=15.80%, reg=520.846692, tot=1120.664047]


  📜 Epoch 17/100 average loss: 3.2961
  🎯 Saved last checkpoint to ./outputs/movenet/last.pt

Epoch 18/100


  🟢 [Training] lr: 0.000094 : 100%|██████████| 16220/16220 [04:54<00:00, 55.01it/s, center=30.45, hm=10.80, offsets=3768.65, regs=23544.09]
  🟡 [Validating] : 100%|██████████| 340/340 [00:08<00:00, 42.04it/s, ct=0.646685, hm=0.225229, off=78.064404, pck=16.05%, reg=521.687961, tot=1122.312239]


  📜 Epoch 18/100 average loss: 3.3009
  🎯 Saved last checkpoint to ./outputs/movenet/last.pt

Epoch 19/100


  🟢 [Training] lr: 0.000093 : 100%|██████████| 16220/16220 [04:54<00:00, 55.06it/s, center=30.43, hm=10.79, offsets=3768.62, regs=23450.19]
  🟡 [Validating] : 100%|██████████| 340/340 [00:07<00:00, 45.24it/s, ct=0.647459, hm=0.225418, off=78.021966, pck=15.57%, reg=525.462846, tot=1129.820534]


  📜 Epoch 19/100 average loss: 3.3230
  🎯 Saved last checkpoint to ./outputs/movenet/last.pt

Epoch 20/100


  🟢 [Training] lr: 0.000092 : 100%|██████████| 16220/16220 [04:49<00:00, 55.99it/s, center=30.42, hm=10.79, offsets=3764.65, regs=23382.52]
  🟡 [Validating] : 100%|██████████| 340/340 [00:07<00:00, 44.19it/s, ct=0.646402, hm=0.225351, off=78.029534, pck=15.66%, reg=523.907374, tot=1126.716038]


  📜 Epoch 20/100 average loss: 3.3139
  🎯 Saved last checkpoint to ./outputs/movenet/last.pt
  📊 Visualized predictions on test image

Epoch 21/100


  🟢 [Training] lr: 0.000091 : 100%|██████████| 16220/16220 [04:51<00:00, 55.61it/s, center=30.39, hm=10.78, offsets=3763.94, regs=23314.30]
  🟡 [Validating] : 100%|██████████| 340/340 [00:08<00:00, 42.25it/s, ct=0.648435, hm=0.225190, off=78.050951, pck=15.76%, reg=524.754231, tot=1128.433037]


  📜 Epoch 21/100 average loss: 3.3189
  🎯 Saved last checkpoint to ./outputs/movenet/last.pt

Epoch 22/100


  🟢 [Training] lr: 0.000091 :  23%|██▎       | 3683/16220 [01:06<03:47, 55.03it/s, center=6.90, hm=2.45, offsets=854.25, regs=5281.43]


KeyboardInterrupt: 