# 關鍵點偵測模型訓練

## 載入資料

In [None]:
import pandas as pd
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np

image_data_dir = "cloth_data_gen/output/images"
keypoint_data_dir = "cloth_data_gen/output/keypoints"

# Build two parallel collections:
#   img_arr           - the images as returned by cv2.imread
#   keypoints_img_arr - for each image, a list of [x, y] integer pixel coordinates
# A sample is kept only when it has at least one keypoint and every keypoint
# lies inside the image.
img_arr = []
keypoints_img_arr = []
skipped_files = []
# Sorted so the sample order does not depend on the filesystem's listing order.
for img_file in sorted(os.listdir(image_data_dir)):
    if not img_file.endswith('.png'):
        continue
    # splitext keeps dots that are part of the base name ("a.b.png" -> "a.b"),
    # whereas split('.')[0] would truncate it to "a".
    name = os.path.splitext(img_file)[0]
    keypoint_file = os.path.join(keypoint_data_dir, name + '.txt')
    image_path = os.path.join(image_data_dir, img_file)
    img = cv2.imread(image_path)
    # cv2.imread returns None (it does not raise) when a file cannot be read,
    # and an image without its keypoint file cannot be used as a sample.
    if img is None or not os.path.isfile(keypoint_file):
        skipped_files.append(img_file)
        continue
    keypoints = pd.read_csv(keypoint_file)
    pixels_coords = keypoints[['x_pixel', 'y_pixel']].values
    height, width = img.shape[:2]
    xs, ys = pixels_coords[:, 0], pixels_coords[:, 1]
    # check if all pixels coordinates are within the image bounds
    in_bounds = (xs >= 0) & (xs < width) & (ys >= 0) & (ys < height)
    if pixels_coords.shape[0] == 0 or not np.all(in_bounds):
        continue
    keypoints_img_arr.append([[int(x), int(y)] for x, y in pixels_coords])
    img_arr.append(img)
if skipped_files:
    print(f'Skipped {len(skipped_files)} unreadable image(s) or image(s) without a keypoint file: {skipped_files}')
# NOTE(review): stacking assumes every kept sample has the same image size and
# the same number of keypoints - confirm against the data generator.
img_arr = np.array(img_arr)
keypoints_img_arr = np.array(keypoints_img_arr)

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class KeypointDataset(Dataset):
    """Dataset pairing images with their keypoint coordinates.

    Both arrays are converted to float32 once, at construction time.
    Indexing returns a dict with two tensors:

    * ``'image'``     - the selected image with its last axis moved to the front
    * ``'keypoints'`` - the selected keypoint array

    If ``transform`` is given (and truthy) it is called with that dict and
    its return value is handed back instead.
    """

    def __init__(self, images, keypoints, transform=None):
        self.transform = transform
        self.images = images.astype(np.float32)
        self.keypoints = keypoints.astype(np.float32)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Axis order (2, 0, 1): channels-last -> channels-first.
        # NOTE(review): callers appear to pass (400, 400, 3) images and
        # (4, 2) keypoints per sample - confirm.
        channels_first = self.images[idx].transpose(2, 0, 1)
        sample = {
            'image': torch.from_numpy(channels_first),
            'keypoints': torch.from_numpy(self.keypoints[idx]),
        }
        return self.transform(sample) if self.transform else sample

## 訓練迴圈

In [None]:
# Switch for the training cell: False trains the model and saves its weights,
# True skips training and loads previously saved weights instead.
load_model = False

In [None]:
from models.utils import *

In [None]:
import torch.optim as optim
import torch.nn as nn
from models.unet import UNet
# from models.yolo_cnn import EnhancedYoloKeypointNet
from models.yolo_vit import HybridKeypointNet
from ultralytics import YOLO
import time

# Mixed-precision tooling and cuDNN setup
from torch.amp import autocast, GradScaler
scaler = GradScaler()  # scales the loss before backward() in the training loop
torch.backends.cudnn.benchmark = True

# Wrap the loaded arrays in a Dataset.
# NOTE(review): presumably img_arr is (n, 400, 400, 3) and keypoints_img_arr
# is (n, 4, 2) - confirm against the data-loading cell.
dataset = KeypointDataset(images=img_arr, keypoints=keypoints_img_arr)

# Random 80/20 train/test split
n_samples = len(dataset)
train_size = int(0.8 * n_samples)
test_size = n_samples - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, test_size]
)

# Both loaders share batch size and pinned memory; only training is shuffled.
loader_options = {"batch_size": 8, "pin_memory": True}
trainloader = DataLoader(train_dataset, shuffle=True, **loader_options)
testloader = DataLoader(test_dataset, shuffle=False, **loader_options)

# loss function

# criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(100.0))
def pairwise_distances(pred, gt):
    """Return the Euclidean distance between every prediction and every target.

    Args:
        pred: tensor of shape (B, K, 2).
        gt: tensor of shape (B, K, 2).

    Returns:
        Tensor of shape (B, K, K) with ``d[b, i, j] = ||pred[b, i] - gt[b, j]||``.
    """
    # Broadcast (B, K, 1, 2) against (B, 1, K, 2) to get all pairs: (B, K, K, 2).
    offsets = pred[:, :, None] - gt[:, None]
    # Collapse the coordinate axis into a length.
    return offsets.norm(dim=-1)

def unordered_keypoint_loss(pred, gt):
    """Order-independent keypoint loss based on nearest-neighbour distances.

    For each sample, averages the distance from every target to its closest
    prediction and from every prediction to its closest target, adds the two
    averages, takes the batch mean and halves it.

    Args:
        pred: tensor of shape (B, K, 2).
        gt: tensor of shape (B, K, 2).

    Returns:
        Scalar tensor.
    """
    distances = pairwise_distances(pred, gt)  # (B, K, K), [b, pred_i, gt_j]
    # Closest prediction for each target, then closest target for each prediction.
    nearest_pred_per_gt = distances.min(dim=1).values  # (B, K)
    nearest_gt_per_pred = distances.min(dim=2).values  # (B, K)
    per_sample = nearest_pred_per_gt.mean(dim=1) + nearest_gt_per_pred.mean(dim=1)
    return 0.5 * per_sample.mean()

import torch
import torch.nn.functional as F

def greedy_keypoint_assignment_loss(pred, gt):
    """Smooth-L1 loss after greedily matching predictions to targets.

    For each sample the closest still-unmatched (prediction, target) pair is
    picked repeatedly until all K points are matched. The summed smooth-L1
    loss of the matched pairs is divided by K, and the result is averaged
    over the batch.

    Args:
        pred: tensor of shape (B, K, 2).
        gt: tensor of shape (B, K, 2).

    Returns:
        Scalar tensor. It is zero (and still connected to ``pred`` for
        autograd) when the batch is empty or K is 0.
    """
    batch_size, K, _ = pred.shape
    if batch_size == 0 or K == 0:
        # Nothing to match: avoid dividing by K == 0 / stacking an empty list.
        return pred.sum() * 0.0

    losses = []
    for b in range(batch_size):
        # The distance matrix is only used to choose indices, so it is computed
        # on detached tensors; gradients flow through smooth_l1_loss below.
        dmat = torch.cdist(pred[b].detach(), gt[b].detach(), p=2)  # (K, K)
        avail_pred = torch.ones(K, dtype=torch.bool, device=pred.device)
        avail_gt = torch.ones(K, dtype=torch.bool, device=gt.device)
        total_loss = 0.0

        for _ in range(K):
            # Only pairs whose prediction AND target are both unmatched compete.
            mask = avail_pred[:, None] & avail_gt[None, :]  # (K, K)
            masked_dmat = dmat.masked_fill(~mask, float('inf'))
            # Flatten and find the position of the smallest remaining distance.
            _, flat_idx = masked_dmat.view(-1).min(0)
            i, j = divmod(flat_idx.item(), K)
            total_loss = total_loss + F.smooth_l1_loss(pred[b, i], gt[b, j], reduction='sum')
            avail_pred[i] = False
            avail_gt[j] = False
        losses.append(total_loss / K)
    return torch.stack(losses).mean()

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Alternative (disabled): YOLO backbone + CNN keypoint head ---
# yolo11 = YOLO('yolo11l-seg.pt')  # Or yolo11m-seg.pt, yolo11x-seg.pt, etc.
# backbone_seq = yolo11.model.model[:12]
# backbone = YoloBackbone(backbone_seq, selected_indices=[0,1,2,3,4,5,6,7,8,9,10,11])
# input_dummy = torch.randn(1, 3, 512, 512)
# with torch.no_grad():
#     feats = backbone(input_dummy)
# print("Feature shapes:", [f.shape for f in feats])
# in_channels_list = [f.shape[1] for f in feats]
#
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# keypoint_net = EnhancedYoloKeypointNet(backbone, in_channels_list)
# model = keypoint_net
# for param in model.backbone.parameters():
#     param.requires_grad = False
# model = model.to(device)

# --- Active: YOLO backbone + ViT hybrid keypoint network ---
yolo11 = YOLO('yolo11l-seg.pt')  # Or yolo11m-seg.pt, yolo11x-seg.pt, etc.
# Keep the first 12 modules of the YOLO network and tap the output of each one.
backbone_seq = yolo11.model.model[:12]
backbone = YoloBackbone(backbone_seq, selected_indices=list(range(12)))

# Probe the backbone once with a dummy input to read the channel count of
# every returned feature map; the keypoint network is built from those counts.
input_dummy = torch.randn(1, 3, 128, 128)
with torch.no_grad():
    feats = backbone(input_dummy)
in_channels_list = [feat.shape[1] for feat in feats]

keypoint_net = HybridKeypointNet(backbone, in_channels_list)
model = keypoint_net

# Freeze the backbone so its parameters receive no gradients.
for frozen_param in model.backbone.parameters():
    frozen_param.requires_grad = False
# for param in model.diffusion.vit.parameters():
#     param.requires_grad = False
model = model.to(device)

# --- Alternative (disabled): plain U-Net ---
# model = UNet(in_channels=3, out_channels=4).to(device)

# Train the compiled model from scratch, or load previously saved weights,
# depending on the `load_model` flag.
compiled_model = torch.compile(model)
# Single source of truth for the checkpoint location so that saving and
# loading can never point at different files.
checkpoint_path = 'models/keypoint_model_vit.pth'
if not load_model:
    optimizer = optim.AdamW(compiled_model.parameters(), lr=1e-5)
    # running_loss only accumulates over the training split, so the epoch
    # average must be taken over that split, not over the full dataset.
    num_train_samples = len(train_dataset)

    for epoch in range(300):
        time_start = time.time()
        compiled_model.train()
        running_loss = 0.0

        for batch in trainloader:
            images = batch["image"].to(device)
            keypoints = batch["keypoints"].to(device)
            optimizer.zero_grad()

            with autocast("cuda", dtype=torch.float16):      # AMP context, not forcing .half()
                outputs = compiled_model(images)
                coords = soft_argmax(outputs)
                loss = greedy_keypoint_assignment_loss(coords, keypoints)

            # Scaled backward pass + optimizer step through the GradScaler.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            # Weight the batch mean by the batch size to get a per-sample sum.
            running_loss += loss.item() * images.size(0)
        print(f'Epoch {epoch+1}: Loss {running_loss / num_train_samples:.4f} time seconds:, {time.time() - time_start}')

    # save the model
    torch.save(compiled_model.state_dict(), checkpoint_path)
else:
    # Load the weights written by the training branch above.
    compiled_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    compiled_model.eval()

## 模型結果分析

In [None]:
# Evaluate on the held-out test split: compute the average loss and save each
# test image with the predicted keypoints drawn on it.
compiled_model.eval()
val_loss = 0.0
image_index = 0  # running counter used in the output file names
# cv2.imwrite returns False without raising when the folder is missing.
os.makedirs('results', exist_ok=True)
with torch.no_grad():
    for batch in testloader:
        images = batch['image'].to(device)
        keypoints = batch['keypoints'].to(device)
        with autocast("cuda", dtype=torch.float16):
            outputs = compiled_model(images)
        coords = soft_argmax(outputs)
        # render the predicted keypoints on the image
        for img, kp in zip(images.cpu().numpy(), coords.cpu().numpy()):
            # Channels-first float tensor -> channels-last 8-bit image.
            # The images were read with cv2.imread and are therefore already
            # BGR, which is what cv2.imwrite expects: no colour conversion.
            # cv2.circle draws in place and needs a contiguous array.
            img = np.transpose(img, (1, 2, 0))
            img = np.ascontiguousarray(np.clip(img, 0, 255).astype(np.uint8))
            for x, y in kp:
                cv2.circle(img, (int(x), int(y)), 1, (0, 0, 255), -1)
            # Write the test-set result into the output folder
            cv2.imwrite(f'results/keypoints_{image_index}.png', img)
            image_index += 1
        loss = greedy_keypoint_assignment_loss(coords, keypoints)
        # Weight the batch mean by the batch size to get a per-sample sum.
        val_loss += loss.item() * images.size(0)
    print(f'Validation Loss: {val_loss / len(test_dataset):.4f}')


## 模型視覺化

In [None]:
from torchview import draw_graph

# Trace `model` with a random sample batch and build its architecture graph,
# expanding nested modules.
sample_batch = torch.randn((8, 3, 128, 128))
model_graph = draw_graph(model, input_data=sample_batch, expand_nested=True)

# Save the diagram as a PNG file ...
architecture_diagram = model_graph.visual_graph
architecture_diagram.render(filename='architecture_full', format='png')
# ... and also show it inline in the Jupyter notebook.
display(architecture_diagram)