# 關鍵點偵測模型訓練

## 載入資料

In [None]:
import pandas as pd
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np
from shared.functions import *

keypoints_data_src = "via_proj/via_project_13Aug2025_15h48m06s.json"

In [None]:
import json

# Load the VIA project JSON export from disk.
# The encoding is given explicitly so that parsing does not depend on the
# platform's default locale encoding (JSON interchange text is UTF-8).
with open(keypoints_data_src, 'r', encoding='utf-8') as f:
    via_data = json.load(f)

# Map each VIA file ID to its filename ('' when an entry has no 'fname' key)
file_id_to_name = {fid: fdata.get('fname', '') for fid, fdata in via_data['file'].items()}

# Function to get keypoints for a specified image filename
def get_keypoints_for_image(filename):
    """Return the annotated (x, y) points of one image from the VIA project.

    Looks `filename` up in the module-level `file_id_to_name` mapping and then
    scans `via_data['metadata']` for entries whose 'vid' equals the first
    matching file ID.

    Args:
        filename: Image filename exactly as stored in the VIA project.

    Returns:
        None when no file entry carries that name; otherwise a list of
        (x, y) tuples (possibly empty), one per metadata entry whose 'xy'
        list has exactly three items.
    """
    not_found = object()
    file_id = next(
        (fid for fid, name in file_id_to_name.items() if name == filename),
        not_found,
    )
    if file_id is not_found:
        return None  # filename is not part of the project

    # 'xy' entries of length 3 are laid out as [1, x, y]; keep x and y only
    candidate_coords = (
        record.get('xy', [])
        for record in via_data['metadata'].values()
        if record['vid'] == file_id
    )
    return [(coords[1], coords[2]) for coords in candidate_coords if len(coords) == 3]

def get_image(filename):
    """Read the image at `filename` with `cv2.imread` and return the result unchanged."""
    return cv2.imread(filename)

from typing import List, Tuple
import numpy as np
# If using cv2/image libraries, you can import cv2 as well for actual resizing.

def resize_image_and_keypoints(
    image: np.ndarray,
    keypoints: List[Tuple[float, float]],
    new_width: int,
    new_height: int
) -> Tuple[np.ndarray, List[Tuple[float, float]]]:
    """
    Resize an image and rescale its keypoints into the new pixel grid.

    Args:
        image (np.ndarray): Source image; its first two axes are read as
            (height, width).
        keypoints (list of (x, y)): Points in source-image pixel coordinates.
        new_width (int): Width of the output image.
        new_height (int): Height of the output image.

    Returns:
        A tuple `(resized_image, resized_keypoints)`: the image resized with
        area interpolation, and the points with x multiplied by
        new_width / old_width and y by new_height / old_height.
    """
    src_height, src_width = image.shape[:2]
    factor_x, factor_y = new_width / src_width, new_height / src_height

    # cv2.resize takes the target size as (width, height)
    target_size = (new_width, new_height)
    resized_image = cv2.resize(image, target_size, interpolation=cv2.INTER_AREA)

    resized_keypoints = []
    for px, py in keypoints:
        resized_keypoints.append((px * factor_x, py * factor_y))
    return resized_image, resized_keypoints

In [None]:
from ultralytics import YOLO

# Path to the fine-tuned YOLO weights
model_path = "models/yolo_finetuned/best.pt"

# Class IDs to keep (indices follow the class order in data.yaml)
allowed_classes = [1]  # bed sheet only

# Load the model
yolo_model_finetuned = YOLO(model_path)

In [None]:
def extract_mask_compare(image_path):
    """Build one binary mask covering every allowed-class instance in an image.

    Runs the module-level `yolo_model_finetuned` on `image_path`, keeps only
    detections whose class ID is in `allowed_classes`, resizes each instance
    mask to the image resolution and ORs them together.

    Args:
        image_path: Path of the image to segment.

    Returns:
        A uint8 array of shape (height, width) of the image; it stays all zero
        when nothing of an allowed class is detected.

    Raises:
        FileNotFoundError: If OpenCV cannot read `image_path`.
    """
    # Read first so an unreadable path fails with a clear message
    orig_img = cv2.imread(image_path)
    if orig_img is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    h, w = orig_img.shape[:2]

    # Inference
    results = yolo_model_finetuned(image_path, task="segment")[0]

    # Start from an empty mask and accumulate every kept instance into it
    mask_all = np.zeros((h, w), dtype=np.uint8)
    for r in results:
        if r.masks is None:
            continue
        masks = r.masks.data.cpu().numpy()     # [N, H_pred, W_pred]
        classes = r.boxes.cls.cpu().numpy()    # [N] class ID of each instance
        for m, cls_id in zip(masks, classes):
            if int(cls_id) not in allowed_classes:
                continue  # skip classes that are not in the allow-list
            m = (m * 255).astype(np.uint8)
            # Masks come at the prediction resolution; bring them to image size
            m = cv2.resize(m, (w, h), interpolation=cv2.INTER_NEAREST)
            mask_all = cv2.bitwise_or(mask_all, m)
    return mask_all

In [None]:
import zipfile

# Build the training arrays: masked depth images, resized colour images,
# original colour images and one-hot keypoint maps.
img_arr = []
keypoints_img_arr = []
rgb_img_arr = []
rgb_img_orig_arr = []
realsense_path = "realsense/realsense_data/realsense_camera/"
realsense_depth_path = "realsense/realsense_data/depth_scaled/"
orig_hw = None
target_size = 128  # side length of the square network input and keypoint map

# sorted() keeps the sample order reproducible; os.listdir order is arbitrary
for f in sorted(os.listdir(realsense_path)):
    # Only colour frames named color_<id>.png drive the loop
    if not (f.startswith("color_") and f.endswith(".png")):
        continue
    fnumber = f[len("color_"):-len(".png")]
    depth_f = "depth_raw_" + fnumber + ".npy"
    color_f = "color_" + fnumber + ".png"

    # Depth map -> grayscale image -> 3-channel image
    depth_map = np.load(realsense_depth_path + depth_f)
    img = depth_map_to_image(depth_map)
    img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

    color_img = get_image(realsense_path + color_f)
    orig_hw = color_img.shape[:2]
    mask = extract_mask_compare(realsense_path + color_f)

    # Only use examples where a mask was detected
    if np.sum(mask) == 0:
        continue

    # Frames missing from the VIA project have no labels; skip them instead of
    # crashing when the keypoints are rescaled
    orig_keypoints = get_keypoints_for_image(color_f)
    if orig_keypoints is None:
        print(f"Skipping {color_f}: no entry in the VIA project")
        continue

    # Mask the depth image
    img[mask == 0] = 0
    rgb_img_orig_arr.append(color_img)

    img, keypoints = resize_image_and_keypoints(img, orig_keypoints, target_size, target_size)
    color_img, keypoints = resize_image_and_keypoints(color_img, orig_keypoints, target_size, target_size)
    # (x, y) -> (row, col)
    keypoints = [[kp[1], kp[0]] for kp in keypoints]

    # One-hot keypoint image; indices are clamped so a point lying exactly on
    # the right/bottom border (or slightly outside) cannot raise or wrap around
    kp_img = np.zeros((target_size, target_size))
    for kp in keypoints:
        row = min(max(int(kp[0]), 0), target_size - 1)
        col = min(max(int(kp[1]), 0), target_size - 1)
        kp_img[row, col] = 1

    img_arr.append(img)
    rgb_img_arr.append(color_img)
    keypoints_img_arr.append(kp_img)
img_arr = np.array(img_arr)
rgb_img_arr = np.array(rgb_img_arr)
rgb_img_orig_arr = np.array(rgb_img_orig_arr)
keypoints_img_arr = np.array(keypoints_img_arr)

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torchvision.transforms.functional as TF
import random

class KeypointDataset(Dataset):
    """Dataset pairing each network input image with its keypoint target.

    All four arrays are converted to float32 on construction and indexed along
    their first axis.

    Args:
        images: Per-sample images, channels last; returned channels first.
        rgb_images: Per-sample colour images, channels last; returned channels first.
        rgb_origs: Per-sample original colour images; returned without transposing.
        keypoints: Per-sample keypoint targets; returned without transposing.
        transform: Optional callable applied to the sample dict before it is returned.
    """

    def __init__(self, images, rgb_images, rgb_origs, keypoints, transform=None):
        self.images, self.rgb_images, self.rgb_orig, self.keypoints = (
            array.astype(np.float32)
            for array in (images, rgb_images, rgb_origs, keypoints)
        )
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return a dict with 'image', 'keypoints', 'rgb_image' and 'rgb_orig' tensors."""
        # HWC -> CHW for the two images that are fed to / shown with the model
        image_chw = np.transpose(self.images[idx], (2, 0, 1))
        rgb_chw = np.transpose(self.rgb_images[idx], (2, 0, 1))

        sample = {
            'image': torch.from_numpy(image_chw),
            'keypoints': torch.from_numpy(self.keypoints[idx]),
            'rgb_image': torch.from_numpy(rgb_chw),
            "rgb_orig": torch.from_numpy(self.rgb_orig[idx]),
        }
        if not self.transform:
            return sample
        return self.transform(sample)

import torch
import torchvision.transforms.functional as TF
import random

class RandomRotateFlip:
    """
    Augmentation applied jointly to the image, colour image and keypoint map:
    - a rotation by an angle drawn uniformly between 0 and 360 degrees
    - then a horizontal flip with 50% chance
    - then a vertical flip with 50% chance
    The 'rgb_orig' entry is passed through untouched.
    """
    def __call__(self, sample):
        depth = sample['image']          # (C, H, W)
        color = sample["rgb_image"]
        heatmap = sample['keypoints']    # (N, H, W) or (H, W)
        original = sample["rgb_orig"]

        # --- Random rotation, same angle for every tensor ---
        angle = random.uniform(0, 360)
        bilinear = TF.InterpolationMode.BILINEAR
        depth = TF.rotate(depth, angle, interpolation=bilinear)
        color = TF.rotate(color, angle, interpolation=bilinear)
        if heatmap.ndim == 3:
            # Already has a leading channel axis: rotate all maps at once
            heatmap = TF.rotate(heatmap, angle, interpolation=bilinear)
        else:
            # Add a temporary channel axis for the rotation, then drop it again
            heatmap = TF.rotate(heatmap.unsqueeze(0), angle, interpolation=bilinear).squeeze(0)

        # --- Random flips after rotation: horizontal first, then vertical ---
        for flip in (TF.hflip, TF.vflip):
            if random.random() < 0.5:
                depth, color, heatmap = flip(depth), flip(color), flip(heatmap)

        return {'image': depth, 'rgb_image': color, 'keypoints': heatmap, "rgb_orig": original}

In [None]:
# Augmentation used for the training split only
rotate_transform = RandomRotateFlip()

# Create the full dataset without transform
full_dataset = KeypointDataset(img_arr, rgb_img_arr, rgb_img_orig_arr, keypoints_img_arr, transform=None)

# Split indices for train/test (80% / remainder); random_split is applied to
# the index range, so both results are index subsets rather than datasets
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_indices, test_indices = torch.utils.data.random_split(range(len(full_dataset)), [train_size, test_size])

# Create train and test datasets with/without transform.
# NOTE(review): each KeypointDataset(...) call converts all four arrays with
# astype(np.float32), so three float32 copies of the data exist at this point.
train_dataset = torch.utils.data.Subset(KeypointDataset(img_arr, rgb_img_arr, rgb_img_orig_arr,keypoints_img_arr, transform=rotate_transform), train_indices)
test_dataset = torch.utils.data.Subset(KeypointDataset(img_arr, rgb_img_arr, rgb_img_orig_arr,keypoints_img_arr, transform=None), test_indices)

# Batches of 8; only the training loader shuffles
trainloader = DataLoader(train_dataset, batch_size=8, shuffle=True, pin_memory=True)
testloader = DataLoader(test_dataset, batch_size=8, shuffle=False, pin_memory=True)

## 測試關鍵點跟圖片的正確性

In [None]:
def _to_display(chw_tensor):
    """Scale a CHW tensor by 1/255, move channels last and swap the first and last channels."""
    hwc = np.transpose(chw_tensor.numpy().copy() / 255, (1, 2, 0))
    return cv2.cvtColor(hwc, cv2.COLOR_RGB2BGR)


# Visual sanity check of one untransformed sample: the keypoint map is drawn
# on top of the model input image, the colour image is shown next to it.
pair = full_dataset[9]
img = _to_display(pair["image"])
rgb_img = _to_display(pair["rgb_image"])

kp = pair["keypoints"].numpy()
print(img.shape, kp.shape)
# Mark every keypoint-map cell above 0.1; cv2 wants (x, y) = (column, row)
for i, j in np.argwhere(kp > 0.1):
    cv2.circle(img, (int(j), int(i)), 1, (0,0,255), -1)
for panel in (img, rgb_img):
    plt.imshow(panel)
    plt.show()

## 訓練迴圈

In [None]:
# Switch read by the training cell: False runs the training loop and saves the
# weights afterwards; True skips training and loads the saved weights instead.
load_model = False

In [None]:
from models.utils import *

In [None]:
def spatial_klloss(pred_map, target_map, eps=1e-8):
    """KL divergence between a normalised target map and a predicted map.

    Each sample is flattened to one vector. `eps` is added to both vectors to
    avoid log(0); the target is then renormalised to sum to 1 (which also
    covers targets with several keypoints), while the prediction is used as
    given and is expected to be a spatial softmax output already.

    Args:
        pred_map: Predicted distribution, (B, 1, H, W).
        target_map: One-hot or few-hot target, (B, H, W).
        eps: Small constant added before taking logarithms.

    Returns:
        Scalar tensor: the per-sample KL divergence averaged over the batch.
    """
    batch = pred_map.shape[0]
    # reshape (unlike view) also accepts non-contiguous tensors such as
    # transposed or sliced maps
    pred = pred_map.reshape(batch, -1) + eps
    target = target_map.reshape(batch, -1) + eps
    target = target / target.sum(dim=1, keepdim=True)
    return (target * (target.log() - pred.log())).sum(dim=1).mean()

def kl_heatmap_loss(pred_hm, gt_hm, mask=None, reduction='mean'):
    """KL divergence KL(gt || pred) between heatmaps treated as distributions.

    Both heatmaps are flattened per sample, floored at 1e-8 and normalised to
    sum to 1 before the divergence is computed.

    Args:
        pred_hm: (B, 1, H, W) model output (expected non-negative).
        gt_hm: (B, 1, H, W) ground-truth heatmap.
        mask: Optional tensor with the same number of elements per sample as
            the heatmaps (1 = valid, 0 = ignored). Ignored positions are
            removed from both distributions before normalisation and add
            nothing to the loss. None uses every position.
        reduction: 'mean' or 'sum' over the batch; any other value returns the
            per-sample losses.

    Returns:
        Scalar tensor for 'mean' / 'sum', otherwise a tensor of shape (B,).
    """
    batch = pred_hm.shape[0]

    def _flatten(hm):
        # Half-precision types cannot represent the 1e-8 floor (it rounds to
        # zero), so those inputs are promoted to float32 first.
        if hm.dtype in (torch.float16, torch.bfloat16):
            hm = hm.float()
        # reshape (unlike view) also accepts non-contiguous tensors
        return hm.reshape(batch, -1)

    # Force positive values
    pred_probs = _flatten(pred_hm).clamp(min=1e-8)
    gt_probs = _flatten(gt_hm).clamp(min=1e-8)

    ignored = None
    if mask is not None:
        ignored = mask.reshape(batch, -1) == 0
        pred_probs = pred_probs.masked_fill(ignored, 0.0)
        gt_probs = gt_probs.masked_fill(ignored, 0.0)

    # Normalise to sum=1. The floor on the denominator only matters for a
    # fully masked sample, whose loss then becomes 0 instead of NaN.
    pred_probs = pred_probs / pred_probs.sum(dim=1, keepdim=True).clamp(min=1e-8)
    gt_probs = gt_probs / gt_probs.sum(dim=1, keepdim=True).clamp(min=1e-8)

    if ignored is None:
        log_pred = pred_probs.log()
    else:
        # log(1) = 0 at ignored positions keeps both the value and the
        # gradient finite there; their target probability is 0 anyway.
        log_pred = pred_probs.masked_fill(ignored, 1.0).log()

    kl_div = F.kl_div(log_pred, gt_probs, reduction='none').sum(dim=1)  # KL per sample

    if reduction == 'mean':
        return kl_div.mean()
    elif reduction == 'sum':
        return kl_div.sum()
    else:
        return kl_div  # shape (B,)
    
def batch_gaussian_blur(x, kernel_size=5, sigma=2.0):
    """
    Gaussian-blur a batch of heatmaps and rescale each map so its peak is 1.

    Args:
        x: Tensor [B, H, W] or [B, 1, H, W]
        kernel_size: Side length of the square blur kernel.
        sigma: Standard deviation used on both axes.
    Returns:
        Tensor shaped like the input. Maps whose blurred maximum is exactly 0
        are divided by 1 instead, so they are returned as blurred.
    """
    # A 3-D input gets a temporary channel axis for the blur
    added_channel = x.dim() == 3
    if added_channel:
        x = x.unsqueeze(1)

    smoothed = TF.gaussian_blur(x, kernel_size=[kernel_size, kernel_size], sigma=[sigma, sigma])
    peak = smoothed.amax(dim=[2, 3], keepdim=True)
    # Replace zero peaks by 1 to avoid division by zero
    peak = torch.where(peak == 0, torch.ones_like(peak), peak)
    rescaled = smoothed / peak

    return rescaled.squeeze(1) if added_channel else rescaled

def batch_entropy(pred_heatmaps):
    """
    Shannon entropy of each sample's heatmap values.

    All channels and spatial positions of a sample are flattened together and
    turned into one probability distribution with a softmax.

    pred_heatmaps: [B, C, H, W]
    Returns: [B] entropy per sample
    """
    n_samples, _channels, _height, _width = pred_heatmaps.shape
    logits = pred_heatmaps.view(n_samples, -1)       # [B, C*H*W]
    dist = torch.softmax(logits, dim=1)              # sums to 1 per sample
    # The 1e-8 offset guards the logarithm against zero probabilities
    weighted_log = dist * torch.log(dist + 1e-8)
    return -weighted_log.sum(dim=1)                  # [B]

In [None]:
import torch.optim as optim
import torch.nn as nn
from models.unet import UNet
# from models.yolo_cnn import EnhancedYoloKeypointNet
from models.yolo_vit import HybridKeypointNet
from ultralytics import YOLO
import time

# Mixed-precision helpers and cuDNN autotuning
from torch.amp import autocast, GradScaler
torch.backends.cudnn.benchmark = True
scaler = GradScaler()

# Pick the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Alternative backbone (disabled): YOLO11 segmentation model + CNN keypoint head
# yolo11 = YOLO('yolo11l-seg.pt')  # Or yolo11m-seg.pt, yolo11x-seg.pt, etc.
# backbone_seq = yolo11.model.model[:12]
# backbone = YoloBackbone(backbone_seq, selected_indices=[0,1,2,3,4,5,6,7,8,9,10,11])
# input_dummy = torch.randn(1, 3, 512, 512)
# with torch.no_grad():
#     feats = backbone(input_dummy)
# print("Feature shapes:", [f.shape for f in feats])
# in_channels_list = [f.shape[1] for f in feats]

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# keypoint_net = EnhancedYoloKeypointNet(backbone, in_channels_list)
# model = keypoint_net
# for param in model.backbone.parameters():
#     param.requires_grad = False
# model = model.to(device)

# Active configuration: the first 10 layers of YOLOv8-L feed a HybridKeypointNet.
# NOTE(review): YoloBackbone is not imported by name in this file; presumably it
# comes from `from models.utils import *` — confirm.
yolo_model = YOLO('yolov8l.pt')
backbone_seq = yolo_model.model.model[:10]
backbone = YoloBackbone(backbone_seq, selected_indices=[0,1,2,3,4,5,6,7,8,9])
# A dummy forward pass is used to read the channel count of every returned feature map
input_dummy = torch.randn(1, 3, 128, 128)
with torch.no_grad():
    feats = backbone(input_dummy)
in_channels_list = [f.shape[1] for f in feats]
keypoint_net = HybridKeypointNet(backbone, in_channels_list)
model = keypoint_net
model = model.to(device)

# Alternative model (disabled): plain U-Net
# model = UNet(in_channels=3, out_channels=4).to(device)

compiled_model = torch.compile(model)
# Load pretrained weights. This happens unconditionally, so the training loop
# that follows starts from this checkpoint rather than from random weights.
compiled_model.load_state_dict(torch.load('models/keypoint_model_vit.pth', map_location=device))
compiled_model.eval()

# Train and save the weights, or (load_model=True) load previously saved weights.
if not load_model:
    optimizer = optim.AdamW(compiled_model.parameters(), lr=1e-5)

    for epoch in range(300):
        time_start = time.time()
        compiled_model.train()
        running_loss = 0.0

        for batch in trainloader:
            images = batch["image"].to(device)
            keypoints = batch["keypoints"].to(device)
            optimizer.zero_grad()

            with autocast("cuda", dtype=torch.float16):      # AMP context, not forcing .half()
                outputs = compiled_model(images)
                # Turn the one-hot keypoint maps into smooth target heatmaps
                keypoints_blur = batch_gaussian_blur(keypoints, kernel_size=31, sigma=3)

                # Uncertainty sampling: train on the half of the batch whose
                # predictions have the highest entropy. The ranking only needs
                # values, so it is computed on a detached copy.
                entropies = batch_entropy(outputs.detach())
                # At least one sample: a final batch of size 1 would otherwise
                # give k = 0 and a NaN loss (mean over an empty selection).
                k = max(1, images.size(0) // 2)
                topk_vals, topk_idx = torch.topk(entropies, k, largest=True)  # highest entropy first
                selected_outputs = outputs[topk_idx]
                selected_keypoints_blur = keypoints_blur[topk_idx]

                # calculate loss
                loss = kl_heatmap_loss(selected_outputs, selected_keypoints_blur.unsqueeze(1))

            # Scaled backward pass and optimizer step for mixed precision
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            running_loss += loss.item() * images.size(0)
        print(f'Epoch {epoch+1}: Loss {running_loss / len(train_dataset):.4f} time seconds:, {time.time() - time_start}')

    # save the model
    torch.save(compiled_model.state_dict(), 'models/keypoint_model_vit_depth.pth')
else:

    compiled_model.load_state_dict(torch.load('models/keypoint_model_vit_depth.pth', map_location=device))
    compiled_model.eval()

## 模型結果分析

In [None]:
# Evaluate on the held-out split and write one keypoint overlay per sample.
compiled_model.eval()
val_loss = 0.0
result_index = 0  # running counter for the output filenames (avoids shadowing builtin iter)
# cv2.imwrite returns False instead of raising when the folder is missing
os.makedirs('results', exist_ok=True)
with torch.no_grad():
    for batch in testloader:
        images = batch['image'].to(device)
        keypoints = batch['keypoints'].to(device)
        # The original-resolution images are only drawn on, so they stay on the CPU
        rgb_origs = batch['rgb_orig'].numpy()
        with autocast("cuda", dtype=torch.float16):
            outputs = compiled_model(images)
            keypoints_blur = batch_gaussian_blur(keypoints, kernel_size=31, sigma=3)
            loss = kl_heatmap_loss(outputs, keypoints_blur.unsqueeze(1))
        # Render the predicted keypoints on the original image
        for kp, rgb_orig in zip(outputs.cpu().numpy(), rgb_origs):
            # Draw on a copy. The images were read with cv2.imread, so the
            # channel order is already what cv2.imwrite expects; (0, 0, 255)
            # is red in that order.
            overlay = rgb_orig.copy()
            kp = kp[0, :, :]
            peaks = thresholded_locations(kp, 0.003)
            for p in peaks:
                i, j = p
                # Scale heatmap (row, col) from the 128x128 grid to original pixels
                cv2.circle(overlay, (int(j * orig_hw[1]/128), int(i * orig_hw[0]/128)), 10, (0,0,255), -1)
            # Write the test-set result into the results folder
            cv2.imwrite(f'results/keypoints_{result_index}.png', overlay)
            result_index += 1
        val_loss += loss.item() * images.size(0)
    print(f'Validation Loss: {val_loss / len(test_dataset):.4f}')


## 模型視覺化

In [None]:
from torchview import draw_graph

# Trace `model` with a random batch of eight 3x128x128 inputs and build its graph,
# expanding nested modules
model_graph = draw_graph(model, input_data=torch.randn((8,3,128,128)), expand_nested=True)
# Write the graph to disk as a PNG named after 'architecture_full'
model_graph.visual_graph.render(filename='architecture_full', format='png')
# or to view inline in a Jupyter notebook:
display(model_graph.visual_graph)