# Inference pipeline

In [1]:
import torch
from resnet import ResNet18, InferenceDataset, load_checkpoint
from torchvision import transforms
import os
from tqdm import tqdm

In [None]:
# --- Configuration ----------------------------------------------------------
# All dataset folders live under one shared root; each path is derived from it.
data_root = "data/raw"
image_dir = f"{data_root}/images"                # images considered for inference
label_dir = f"{data_root}/labels"                # existing label files (<stem>.txt)
pseudo_label_dir = f"{data_root}/pseudo_labels"  # existing pseudo-label files (<stem>.txt)
output_dir = f"{data_root}/predictions"          # where prediction files are written

checkpoint_path = "checkpoints/epoch_100.pth"    # checkpoint the model weights are read from

num_keypoints = 11  # forwarded to the ResNet18 constructor
batch_size = 4      # forwarded to the DataLoader

# Make sure the output folder exists before anything is written (no error if it already does).
os.makedirs(output_dir, exist_ok=True)

In [3]:
# Prefer the GPU when torch reports one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Read the checkpoint (tensors mapped onto `device`), build the network, and
# restore its weights from the checkpoint's 'model' entry.
checkpoint = torch.load(checkpoint_path, weights_only=True, map_location=device)
model = ResNet18(num_keypoints=num_keypoints)
model.load_state_dict(checkpoint['model'])

# Place the network on the chosen device and switch it to evaluation mode.
model.to(device)
model.eval()

# Preprocessing applied to every image: resize -> tensor -> per-channel normalization.
# NOTE(review): torchvision's Resize takes (height, width), so this yields tensors
# 1920 high by 1080 wide, while the post-processing cell scales x by 1920 and y by
# 1080. Confirm this matches the transform used during training before changing it.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],  # these values are the standard ImageNet channel means
    std=[0.229, 0.224, 0.225],   # ... and standard deviations
)
transform = transforms.Compose([
    transforms.Resize((1920, 1080)),
    transforms.ToTensor(),
    normalize,
])

In [4]:
# Collect the images that have neither a label nor a pseudo-label yet.
# A (pseudo-)label belongs to an image when "<image stem>.txt" exists in the
# corresponding folder.

image_extensions = (".jpg", ".png")  # suffix match is case-sensitive
image_paths = []

# os.listdir() returns entries in arbitrary order; sorting makes the inference
# order identical on every run and on every machine.
for filename in sorted(os.listdir(image_dir)):
    if not filename.endswith(image_extensions):
        continue

    label_file = os.path.splitext(filename)[0] + '.txt'
    label_path = os.path.join(label_dir, label_file)
    pseudo_label_path = os.path.join(pseudo_label_dir, label_file)
    if os.path.exists(label_path) or os.path.exists(pseudo_label_path):
        # Already labelled one way or the other: nothing to predict.
        continue

    image_paths.append(os.path.join(image_dir, filename))

print(f"Found {len(image_paths)} images without labels or pseudo-labels. Running inference...")

Found 190 images without labels or pseudo-labels. Running inference...


In [5]:
# Wrap the selected image paths in a dataset and batch them for the model.
# shuffle=False keeps the batches in the same order as image_paths.
dataset = InferenceDataset(image_paths, transform=transform)
dataloader = torch.utils.data.DataLoader(
    dataset,
    shuffle=False,
    batch_size=batch_size,
)

In [6]:
# Run inference

# Pixel dimensions used to scale the model outputs (treated as 0-1 coordinates)
# back to image coordinates.
frame_width = 1920
frame_height = 1080

all_keypoints = []    # one list of (x, y) pixel tuples per image
all_image_paths = []  # image path for the entry at the same index in all_keypoints

with torch.no_grad():  # inference only, no gradient tracking needed
    for images, paths in tqdm(dataloader):    # plural because they are batches
        images = images.to(device)
        outputs = model(images)
        batch_keypoints = outputs.cpu().numpy()

        # Post-process (0-1 to pixels)
        for path, keypoint_set in zip(paths, batch_keypoints):
            # Each row is read as interleaved [x0, y0, x1, y1, ...]: even indices
            # are x values, odd indices are y values. Slicing handles all
            # keypoints at once instead of looping over index pairs.
            # astype(int) truncates toward zero rather than rounding.
            x_coords = (keypoint_set[0::2] * frame_width).astype(int)
            y_coords = (keypoint_set[1::2] * frame_height).astype(int)
            keypoints = list(zip(x_coords, y_coords))

            # Build the list once and store it (previously it was built twice
            # and the first copy was discarded).
            all_keypoints.append(keypoints)
            all_image_paths.append(path)

100%|██████████| 48/48 [00:08<00:00,  5.99it/s]


In [7]:
# Write one text file per image into output_dir, named "<image stem>.txt".
# Each line of a file holds one keypoint as "x y".

for image_path, points in zip(all_image_paths, all_keypoints):
    stem = os.path.splitext(os.path.basename(image_path))[0]
    pred_path = os.path.join(output_dir, stem + '.txt')
    with open(pred_path, 'w') as f:
        f.writelines(f"{x} {y}\n" for x, y in points)