In [1]:
import sys;sys.path.insert(0, '..')
from src.RMSE import BatchRMSE
import torch
import os
from src.UNETMS import UNETMS
import src.core as core
import src.guassian as guassian


cwd = os.getcwd()
FCN8_model_dir = os.path.join(cwd, '..', 'models', 'UNETMS_FINAL_NO_AUG')

run_dir = os.path.join(FCN8_model_dir, 'run')
model_dir = os.path.join(FCN8_model_dir, 'model')


print(f"Using {core.TorchDevice} device")
torch.set_default_device(core.TorchDevice)

loaded_model = UNETMS()
saved_model_path = model_path = os.path.join(model_dir, 'model_20240407_161313_15')
checkpoint = torch.load(saved_model_path, map_location=core.TorchDevice)
loaded_model.load_state_dict(checkpoint)
loaded_model.eval()

rmse = BatchRMSE()

predicted_heatmaps = None
single_sample = None
for i, sample in enumerate(core.TestDataSet):
    single_sample = sample
    image, heatmaps = sample
    image, heatmaps = image.to(core.TorchDevice), heatmaps.to(core.TorchDevice)

    ## We need to unsqueeze add the batch dimension to the image [1, 3, 96, 96]
    image = image.unsqueeze(0)
    # print("image.shape", image.shape)

    predicted_heatmaps = loaded_model(image)

    pred_heatmaps = predicted_heatmaps.cpu().detach()

    ## We need to remove the extra batch dimensions from the prediction [24, 96, 96]
    pred_heatmaps = pred_heatmaps.squeeze(0)

    predicted_keypoints = guassian.heatmaps_to_keypoints_CoM(pred_heatmaps)

    ground_heatmaps = heatmaps.cpu().detach()

    gt_keypoints = guassian.heatmaps_to_keypoints_CoM(ground_heatmaps)

    rmse.add_pred_error(gt_keypoints, predicted_keypoints)


print(f"All Keypoints RMSE {rmse.get_all_keypoints_RMSE()}")



Using cpu device
All Keypoints RMSE 160.98668095669527


In [2]:
for i in range(24):
    try:
        print(f"All Keypoints RMSE {rmse.get_keypoint_RMSE(i)}")
    except:
        print(f"No RMSE for {i}th keypoint")

All Keypoints RMSE 10.357762228136226
All Keypoints RMSE 8.159668784978283
All Keypoints RMSE 9.540467705527524
All Keypoints RMSE 14.292934237973325
All Keypoints RMSE 12.573647476429311
All Keypoints RMSE 12.13224201113212
All Keypoints RMSE 9.964056083825977
All Keypoints RMSE 34.62132612872578
All Keypoints RMSE 7.310550494508979
All Keypoints RMSE 13.891667165010054
All Keypoints RMSE 849.0335837940518
All Keypoints RMSE 127.73339388608845
All Keypoints RMSE 13.757418815084224
All Keypoints RMSE 18.7142091849039
All Keypoints RMSE 7.910704442541396
All Keypoints RMSE 10.484529011655711
All Keypoints RMSE 9.353355631837948
All Keypoints RMSE 8.5987419608658
All Keypoints RMSE 12.071528750314783
All Keypoints RMSE 13.659487087738146
No RMSE for 20th keypoint
No RMSE for 21th keypoint
No RMSE for 22th keypoint
No RMSE for 23th keypoint
