# 6DRepNet360 Head Pose Estimation Example

This notebook demonstrates how to use the 6DRepNet360 model for head pose estimation on a single image.

## Steps:
1. Check GPU availability and download pre-trained model
2. Import required libraries and modules
3. Load and configure the model
4. Prepare image transformations
5. Load an image and run inference
6. Visualize results

In [None]:
# Step 1: Check GPU availability
# Shell escape: runs `nvidia-smi` to print the GPU/driver status. Informational only;
# the model-loading cell picks CUDA or CPU itself via torch.cuda.is_available().
!nvidia-smi

In [None]:
# Download pre-trained model
# Saves the checkpoint into the current working directory under the exact filename
# that the later cells look for. Re-running this cell downloads the file again
# (wget -O overwrites the existing file).
!wget https://cloud.ovgu.de/s/TewGC9TDLGgKkmS/download -O 6DRepNet360_Full-Rotation_300W_LP+Panoptic.pth

In [2]:
# Step 2: Import required libraries
import os
import sys
import cv2
import torch
import torchvision
from torchvision import transforms
import torch.nn as nn
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# Early, non-fatal check: are the pre-trained weights in the working directory
# or in its parent? Only a hint is printed here; nothing is raised.
model_filename = '6DRepNet360_Full-Rotation_300W_LP+Panoptic.pth'
candidate_locations = (model_filename, os.path.join('..', model_filename))
model_exists = any(os.path.exists(location) for location in candidate_locations)

if not model_exists:
    print(
        f"Model file '{model_filename}' not found.",
        "Please run the download cell above.",
        sep="\n",
    )

In [3]:
# Import project modules
# `utils` and `test` are imported as top-level modules from the repository's
# `sixdrepnet360/` directory, so that directory must be on sys.path first.
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)

# Support launching the notebook from examples/ (the directory is then under the
# parent of the cwd) as well as from the project root (directory under the cwd).
candidate_dirs = [
    os.path.join(parent_dir, 'sixdrepnet360'),
    os.path.join(current_dir, 'sixdrepnet360'),
]
# Fall back to the original parent-relative location when neither exists, so the
# path that ends up on sys.path is unchanged from the previous behaviour.
sixd_dir_path = next((d for d in candidate_dirs if os.path.isdir(d)), candidate_dirs[0])

if not os.path.isdir(sixd_dir_path):
    # Without this hint the failure below is confusing: `test` is also the name of a
    # standard-library package, so the import error would not mention the missing directory.
    print(f"Warning: 'sixdrepnet360' directory not found; looked in: {candidate_dirs}")

# Insert at the front so these modules win over same-named modules elsewhere on
# sys.path. NOTE(review): if a different `test` or `utils` module was already imported
# in this kernel, the cached module is reused; restart the kernel in that case.
if sixd_dir_path not in sys.path:
    sys.path.insert(0, sixd_dir_path)

import utils
from test import SixDRepNet360

In [4]:
# Step 3: Load and configure the model
%matplotlib inline

# Initialize model (ResNet-50 based)
model = SixDRepNet360(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3])

# Load pre-trained weights (check both current dir and parent dir)
model_filename = '6DRepNet360_Full-Rotation_300W_LP+Panoptic.pth'
model_paths = [
    model_filename,  # Current directory (examples/)
    os.path.join('..', model_filename)  # Parent directory (project root)
]

saved_model_path = None
for path in model_paths:
    if os.path.exists(path):
        saved_model_path = path
        break

if saved_model_path is None:
    raise FileNotFoundError(f"Model file '{model_filename}' not found in current or parent directory")

print(f"Loading model from: {saved_model_path}")

saved_state_dict = torch.load(saved_model_path, map_location='cpu')
model.load_state_dict(saved_state_dict)

# Move to device and set to evaluation mode
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

print(f"Model loaded successfully on {device}")

Loading model from: 6DRepNet360_Full-Rotation_300W_LP+Panoptic.pth
Model loaded successfully on cuda:0
Model loaded successfully on cuda:0


In [5]:
# Step 4: Prepare image transformations
# Per-channel normalization statistics (presumably the usual ImageNet values;
# confirm against the training setup).
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

# Applied in order: resize, center-crop to 224, convert to tensor, normalize.
preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
transformations = transforms.Compose(preprocessing_steps)

In [None]:
# Step 5: Load image and run inference
# Replace with your image path
image_path = '/path/to/your/image.jpg'

# Load and preprocess image.
# The context manager closes the underlying file as soon as the RGB copy exists;
# convert() returns a new image that does not depend on the open file.
with Image.open(image_path) as source_image:
    img_pil = source_image.convert('RGB')
input_tensor = transformations(img_pil)
input_tensor = input_tensor.unsqueeze(0).to(device)  # add a leading batch dimension

# Run inference (no gradients needed)
with torch.no_grad():
    R_pred = model(input_tensor)

# Convert rotation matrix to Euler angles (radians -> degrees).
# The result's columns are read below as pitch, yaw, roll, in that order.
euler_angles = utils.compute_euler_angles_from_rotation_matrices(R_pred) * 180 / np.pi
pitch_deg = euler_angles[0, 0].cpu().item()
yaw_deg = euler_angles[0, 1].cpu().item()
roll_deg = euler_angles[0, 2].cpu().item()

print("--- Inference Results ---")
print(f"Yaw: {yaw_deg:.2f}°")
print(f"Pitch: {pitch_deg:.2f}°")
print(f"Roll: {roll_deg:.2f}°")

--- Inference Results ---
Yaw: -63.21°
Pitch: 166.03°
Roll: 179.67°


In [None]:
# Step 6: Visualize results
# Work on a BGR copy of the PIL image (OpenCV channel order) for drawing.
cv2_img = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)

# Draw pose axes on the image; the return value is not used, so the drawing is
# expected to happen on cv2_img in place.
utils.draw_axis(cv2_img, yaw_deg, pitch_deg, roll_deg)

# Display result: convert back to RGB for matplotlib
fig, ax = plt.subplots(figsize=(10, 8))
ax.imshow(cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB))
ax.axis('off')
ax.set_title(f"Head Pose Estimation\nYaw: {yaw_deg:.1f}°, Pitch: {pitch_deg:.1f}°, Roll: {roll_deg:.1f}°")
plt.show()