In [2]:
# Animal Pose Estimation Demo with AP-10K HRNet Model
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from urllib.request import urlretrieve
import warnings
warnings.filterwarnings('ignore')

In [None]:
# # Install required packages
# %pip install mmpose mmcv mmdet opencv-python matplotlib


In [3]:
# Import MMPose modules
try:
    from mmpose.apis import init_model, inference_topdown
    from mmpose.utils import register_all_modules
    from mmdet.apis import init_detector, inference_detector
    import mmpose
    print(f"MMPose version: {mmpose.__version__}")
except ImportError as e:
    print(f"Import error: {e}")
    print("Please install mmpose and dependencies")


MMPose version: 1.3.2


In [7]:
# Model configuration
# Model configuration
CONFIG_URL = (
    "https://raw.githubusercontent.com/open-mmlab/mmpose/main/"
    "configs/animal_2d_keypoint/topdown_heatmap/ap10k/"
    "td-hm_hrnet-w48_8xb64-210e_ap10k-256x256.py"
)

MODEL_URL = "https://download.openmmlab.com/mmpose/animal/hrnet/hrnet_w48_ap10k_256x256-0e4a3462_20220830.pth"

# Detection model for animal detection
# Detection model for animal detection
DET_CONFIG_URL = (
    "https://raw.githubusercontent.com/open-mmlab/mmdetection/main/"
    "configs/yolox/yolox_x_8x8_300e_coco.py"
)
DET_MODEL_URL = "https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_x_8x8_300e_coco/yolox_x_8x8_300e_coco_20211126_140254-1ef88d67.pth"

# Create directories
os.makedirs('models', exist_ok=True)
os.makedirs('configs', exist_ok=True)

# Download model files
def download_file(url, filename):
    if not os.path.exists(filename):
        print(f"Downloading {filename}...")
        urlretrieve(url, filename)
        print(f"Downloaded {filename}")
    else:
        print(f"{filename} already exists")

# Download config and model files
pose_config = "configs/hrnet_w48_ap10k_256x256.py"
pose_model = "models/hrnet_w48_ap10k_256x256.pth"
det_config = "configs/yolox_x_8x8_300e_coco.py"
det_model = "models/yolox_x_8x8_300e_coco.pth"

download_file(CONFIG_URL, pose_config)
download_file(MODEL_URL, pose_model)
download_file(DET_CONFIG_URL, det_config)
download_file(DET_MODEL_URL, det_model)


Downloading configs/hrnet_w48_ap10k_256x256.py...
Downloaded configs/hrnet_w48_ap10k_256x256.py
Downloading models/hrnet_w48_ap10k_256x256.pth...


HTTPError: HTTP Error 404: Not Found

In [None]:
# Initialize models
print("Initializing models...")
register_all_modules()

# Initialize detection model
det_model = init_detector(det_config, det_model, device='cpu')

# Initialize pose estimation model 
pose_model = init_model(pose_config, pose_model, device='cpu')

print("Models initialized successfully!")


In [None]:
# Load and process the cow image
COW_IMG = 'data/bonting-identification/samples/ear-tags/detection/sample_cow_ear_tag_midrange_clear.png'

# Check if image exists
if not os.path.exists(COW_IMG):
    print(f"Image not found: {COW_IMG}")
    print("Please make sure the image path is correct")
    # Create a dummy image for demonstration
    import numpy as np
    dummy_img = np.ones((256, 256, 3), dtype=np.uint8) * 128
    COW_IMG = "dummy_cow.jpg"
    cv2.imwrite(COW_IMG, dummy_img)
    print(f"Created dummy image: {COW_IMG}")

# Load image
image = cv2.imread(COW_IMG)
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

print(f"Image shape: {image.shape}")
plt.figure(figsize=(10, 8))
plt.imshow(image_rgb)
plt.title("Input Cow Image")
plt.axis('off')
plt.show()


In [None]:
# Run animal detection first
print("Running animal detection...")
det_results = inference_detector(det_model, COW_IMG)

# Extract bounding boxes for animals (assuming class 0 is person, but we'll use any detection)
# For COCO dataset: cow is class 21
bboxes = []
if hasattr(det_results, 'pred_instances'):
    # MMDetection v3.0+ format
    if len(det_results.pred_instances.bboxes) > 0:
        bboxes = det_results.pred_instances.bboxes.cpu().numpy()
        scores = det_results.pred_instances.scores.cpu().numpy()
        # Filter by confidence threshold
        valid_idx = scores > 0.3
        bboxes = bboxes[valid_idx]
else:
    # Legacy format
    for i, bbox_list in enumerate(det_results):
        if len(bbox_list) > 0:
            bboxes.extend(bbox_list[bbox_list[:, 4] > 0.3])

if len(bboxes) == 0:
    # If no detection, use full image
    h, w = image.shape[:2]
    bboxes = [[0, 0, w, h, 1.0]]
    print("No detection found, using full image")
else:
    print(f"Found {len(bboxes)} detection(s)")

print("Detection bboxes:", bboxes[:3])  # Show first 3 boxes


In [None]:
# Run pose estimation on detected animals
print("Running pose estimation...")
pose_results = inference_topdown(pose_model, COW_IMG, bboxes)

print(f"Pose estimation completed! Found {len(pose_results)} result(s)")

# Display results
if len(pose_results) > 0:
    for i, result in enumerate(pose_results):
        print(f"Result {i}: {len(result.pred_instances.keypoints[0])} keypoints detected")
        print(f"Keypoint scores: {result.pred_instances.keypoint_scores[0][:5]}...")  # Show first 5 scores


In [None]:
# Visualize results
def visualize_pose_results(image, pose_results, keypoint_threshold=0.3):
    """Visualize pose estimation results"""
    # AP-10K keypoint names (17 keypoints)
    keypoint_names = [
        'Left Eye', 'Right Eye', 'Nose', 'Neck', 'Root of Tail',
        'Left Shoulder', 'Left Elbow', 'Left Front Paw', 
        'Right Shoulder', 'Right Elbow', 'Right Front Paw',
        'Left Hip', 'Left Knee', 'Left Back Paw',
        'Right Hip', 'Right Knee', 'Right Back Paw'
    ]
    
    # Create a copy of the image for visualization
    vis_image = image.copy()
    
    # Colors for different keypoints (BGR format)
    colors = [
        (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255),
        (0, 255, 255), (128, 0, 0), (0, 128, 0), (0, 0, 128), (128, 128, 0),
        (128, 0, 128), (0, 128, 128), (255, 128, 0), (255, 0, 128), (128, 255, 0),
        (0, 255, 128), (128, 0, 255)
    ]
    
    for result in pose_results:
        keypoints = result.pred_instances.keypoints[0]  # Shape: [17, 3] for AP-10K
        keypoint_scores = result.pred_instances.keypoint_scores[0]  # Shape: [17]
        
        # Draw keypoints
        for i, ((x, y, v), score) in enumerate(zip(keypoints, keypoint_scores)):
            if score > keypoint_threshold and v > 0:  # Only draw visible keypoints with good confidence
                color = colors[i % len(colors)]
                cv2.circle(vis_image, (int(x), int(y)), 5, color, -1)
                cv2.putText(vis_image, f'{i}', (int(x)+5, int(y)-5), 
                           cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
    
    return vis_image

# Visualize the results
if len(pose_results) > 0:
    vis_image = visualize_pose_results(image, pose_results)
    vis_image_rgb = cv2.cvtColor(vis_image, cv2.COLOR_BGR2RGB)
    
    # Display side by side
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7))
    
    ax1.imshow(image_rgb)
    ax1.set_title("Original Image")
    ax1.axis('off')
    
    ax2.imshow(vis_image_rgb)
    ax2.set_title("Pose Estimation Results (AP-10K 17 keypoints)")
    ax2.axis('off')
    
    plt.tight_layout()
    plt.show()
    
    # Print keypoint information
    print("\\nAP-10K Keypoint Mapping:")
    keypoint_names = [
        'Left Eye', 'Right Eye', 'Nose', 'Neck', 'Root of Tail',
        'Left Shoulder', 'Left Elbow', 'Left Front Paw', 
        'Right Shoulder', 'Right Elbow', 'Right Front Paw',
        'Left Hip', 'Left Knee', 'Left Back Paw',
        'Right Hip', 'Right Knee', 'Right Back Paw'
    ]
    for i, name in enumerate(keypoint_names):
        print(f"{i}: {name}")
else:
    print("No pose estimation results to visualize")
