In [None]:
# Re-import edited modules (e.g. the local `cvmt` package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys

import wandb
from easydict import EasyDict
from cvmt.ml.trainer import create_dataloader, SingletaskTraining, mean_radial_error, max_indices_4d_tensor
from cvmt.utils import (load_yaml_params, nested_dict_to_easydict)

from cvmt.ml.models import load_model

import torch
from typing import *
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.lines import Line2D

In [None]:
# Run from the repository root so relative paths such as `configs/params.yaml`
# resolve (assumes this notebook lives two directories below the root).
# The guard keeps the cell idempotent: without it every re-run would climb two
# more directories.
if not os.path.isfile("configs/params.yaml"):
    os.chdir("../../")

In [None]:
def load_env_file(path: str) -> None:
    """Export ``KEY=VALUE`` lines from a dotenv-style file into ``os.environ``.

    ``!source configs/.env`` runs in a throwaway subshell, so the variables it
    defines never reach this kernel. Parsing the file here makes them visible to
    Python code running in the notebook.

    Supported syntax: blank lines, ``#`` comment lines, an optional leading
    ``export``, and values wrapped in matching single or double quotes. Shell
    expansion and inline comments are not supported. Existing variables with the
    same name are overwritten, as ``source`` would do.
    """
    with open(path) as env_file:
        for raw_line in env_file:
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            if line.startswith("export "):
                line = line[len("export "):].lstrip()
            key, sep, value = line.partition("=")
            key = key.strip()
            if not sep or not key:
                continue  # not an assignment; ignore it
            value = value.strip()
            if len(value) >= 2 and value[0] == value[-1] and value[0] in "'\"":
                value = value[1:-1]
            os.environ[key] = value


ENV_FILE_PATH = "configs/.env"
if os.path.exists(ENV_FILE_PATH):
    load_env_file(ENV_FILE_PATH)
else:
    print(f"{ENV_FILE_PATH} not found; environment variables were not loaded.")

In [None]:
# YAML file with the notebook settings (the cells below read its TRAIN, MODEL,
# VERIFY and WANDB sections). The path is relative to the current working
# directory, i.e. the repository root set by the `os.chdir` cell.
CONFIG_PARAMS_PATH = "configs/params.yaml"

In [None]:
# Parse the YAML config and wrap it so sections can be read as attributes (params.TRAIN.…).
params: EasyDict = nested_dict_to_easydict(load_yaml_params(CONFIG_PARAMS_PATH))

In [None]:
# Resolve which checkpoint to evaluate. A local path in `VERIFY.CHECKPOINT_PATH`
# takes precedence; otherwise the `VERIFY.WANDB_CHECKPOINT_REFERENCE_NAME`
# artifact is downloaded from W&B. Both keys are read from the same VERIFY
# section (the original checked TEST.* but read VERIFY.*).
# Errors from W&B are no longer swallowed: previously they were printed and the
# cell carried on, failing later with a confusing NameError on `checkpoint_path`.
checkpoint_path = params.VERIFY.get("CHECKPOINT_PATH")
wandb_checkpoint_reference = params.VERIFY.get("WANDB_CHECKPOINT_REFERENCE_NAME")

if checkpoint_path:
    pass  # use the local checkpoint as-is
elif wandb_checkpoint_reference:
    artifact = wandb.Api().artifact(wandb_checkpoint_reference)
    artifact_dir = artifact.download()
    checkpoint_path = os.path.join(artifact_dir, "model.ckpt")
else:
    raise ValueError(
        "You have to define either `VERIFY.CHECKPOINT_PATH` or "
        "`VERIFY.WANDB_CHECKPOINT_REFERENCE_NAME` in the config/params.yaml"
    )

print(checkpoint_path)

## Create data loaders

In [None]:
# When True, the evaluation cell replaces the freshly built model with the one
# loaded from `checkpoint_path`.
use_pretrain = True

task_config = params.TRAIN.V_LANDMARK_TASK
# `batch_size` is not used below: the dataloaders are built with batch_size=1.
task_id, batch_size = task_config.TASK_ID, task_config.BATCH_SIZE

loss_name, optim_params, model_params = (
    params.TRAIN.LOSS_NAME,
    params.TRAIN.OPTIMIZER,  # not used later in this notebook
    params.MODEL.PARAMS,
)

# Maximum number of examples kept (and plotted) per MRE bucket.
# The key really is spelled `N_IMGAES_TO_PLOT` in params.yaml.
n_images_to_plot = params.VERIFY.N_IMGAES_TO_PLOT

In [None]:
def _make_eval_dataloader(split: str):
    """Return an unshuffled, batch-size-1 dataloader for `split` of the current task.

    Reads `task_id` and `params` from the notebook namespace at call time.
    """
    return create_dataloader(
        task_id=task_id,
        batch_size=1,  # one sample per batch: the evaluation loop keeps only `images[0]`
        split=split,
        shuffle=False,  # deterministic sample order across runs
        params=params,
        sampler_n_samples=None,
    )


# train dataloader
train_dataloader = _make_eval_dataloader('train')
# val dataloader
val_dataloader = _make_eval_dataloader('val')

## Load the model
### Wrap the model in a PyTorch Lightning module

See the issue below for why a PyTorch Lightning model cannot be loaded from a checkpoint when its hyperparameters were saved,
and how to use a PyTorch Lightning module with a model that is defined outside of it.

https://github.com/Lightning-AI/lightning/issues/3629#issue-707536217

In [None]:
# Build the network and wrap it in the Lightning module used for training.
model = load_model(**model_params)

pl_module = SingletaskTraining(
    model=model,
    task_id=task_id,
    checkpoint_path=checkpoint_path,
    loss_name=loss_name,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if use_pretrain:
    # `load_from_checkpoint` is a classmethod that returns a *new* module, so call
    # it on the class (newer Lightning releases raise if it is called on an
    # instance). `map_location` lets a checkpoint saved on GPU load on CPU.
    model = SingletaskTraining.load_from_checkpoint(
        checkpoint_path, map_location=device,
    ).model

model.to(device)
model.eval()

# MRE cut-offs (in the units returned by `mean_radial_error`) used to bucket
# validation samples for visual inspection. The buckets are half-open, so every
# sample lands in exactly one of them:
#   perf: mre < 1.22    med: 1.22 <= mre < 9.6    bad: mre >= 9.6
MRE_PERF_UPPER = 1.22
MRE_BAD_LOWER = 9.6

val_radial_errors = []  # one MRE per validation batch (batch_size=1 -> per sample)

perf_samples = []
med_samples = []
bad_samples = []

with torch.no_grad():
    for batch in val_dataloader:
        images = batch['image'].to(device)
        targets = batch['v_landmarks'].to(device)
        predictions = model(images, task_id=task_id)

        mre = mean_radial_error(preds=predictions, targets=targets).item()
        val_radial_errors.append(mre)

        if np.isnan(mre):
            continue  # NaN cannot be compared with the cut-offs; don't bucket it
        if mre >= MRE_BAD_LOWER:
            label, samples = 'bad', bad_samples
        elif mre >= MRE_PERF_UPPER:
            label, samples = 'med', med_samples
        else:
            label, samples = 'perf', perf_samples

        # Keep at most `n_images_to_plot` examples per bucket; every sample's
        # MRE is still recorded above for the summary statistics.
        if len(samples) >= n_images_to_plot:
            continue

        # Landmark coordinates from `max_indices_4d_tensor`, squeezed to drop
        # singleton dims (TODO confirm: expected shape (n_landmarks, 2)).
        landmarks_coords = {
            'preds': np.squeeze(max_indices_4d_tensor(predictions).cpu().numpy()),
            'targets': np.squeeze(max_indices_4d_tensor(targets).cpu().numpy()),
        }
        samples.append([images[0].cpu().numpy(), landmarks_coords, mre])
        print(f"{label} - counter: {len(samples)} , mre: {mre}")


In [None]:
# Summary statistics of the per-sample validation MRE, labeled for readability.
{
    "n_samples": len(val_radial_errors),
    "median": np.median(val_radial_errors),
    "p25": np.percentile(val_radial_errors, 25),
    "p75": np.percentile(val_radial_errors, 75),
    "mean": np.mean(val_radial_errors),
    "std": np.std(val_radial_errors),
}

## Visualize predicted vs. target landmarks

In [None]:
def plot_images_and_landmark_coords(items: List[Any], category: str='all'):
    """Show each sample's image with its landmarks overlaid, two samples per row.

    Args:
        items: samples as collected by the evaluation loop. Each is an
            ``[image, landmarks, mre]`` sequence:
            - ``image`` is an array that ``squeeze()`` turns into something
              ``imshow`` can draw;
            - ``landmarks`` is a dict with ``'preds'`` and ``'targets'``
              coordinate arrays whose entries are indexed as ``(row, col)``;
            - ``mre`` is the sample's mean radial error.
        category: ``'all'`` (default) draws targets and predictions;
            ``'targets'`` or ``'preds'`` draws only that set.

    Raises:
        ValueError: if ``category`` is not ``'all'``, ``'targets'`` or ``'preds'``.
    """
    # (legend label, colour) per landmark set; insertion order is the drawing
    # order, so predictions are drawn on top of targets.
    styles = {
        'targets': ('Target Landmarks', 'yellow'),
        'preds': ('Predicted Landmarks', 'cyan'),
    }
    if category == 'all':
        categories = list(styles)
    elif category in styles:
        categories = [category]
    else:
        raise ValueError(
            f"category must be 'all', 'targets' or 'preds', got {category!r}"
        )

    if not items:
        print("No samples to plot.")
        return

    n_cols = 2
    if len(items) > 1:
        # Ceil division so an odd number of samples still gets one panel each.
        n_rows = (len(items) + n_cols - 1) // n_cols
        fig, axs = plt.subplots(n_rows, n_cols, figsize=(16, 8 * n_rows))
        axs = axs.flatten()
    else:
        fig, axs = plt.subplots(1, 1, figsize=(16, 16))
        axs = [axs]

    # Legend lists only the landmark sets that are actually drawn.
    legend_elements = [
        Line2D([0], [0], marker='o', color='w', label=label,
               markerfacecolor=colour, markersize=10)
        for key, (label, colour) in styles.items() if key in categories
    ]

    for ax, (image, landmarks, mre) in zip(axs, items):
        ax.imshow(image.squeeze(), cmap='gray')
        for key in categories:
            _, colour = styles[key]
            for landmark in landmarks[key]:
                # Landmarks are (row, col); matplotlib expects (x, y) = (col, row).
                ax.add_patch(patches.Circle((landmark[1], landmark[0]), radius=1, color=colour))
        ax.set_title(f'MRE={mre:.3f}')
        ax.legend(handles=legend_elements, loc='lower right')

    # Hide the spare panel left over when the number of samples is odd.
    for ax in axs[len(items):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
plot_images_and_landmark_coords(perf_samples)

In [None]:
plot_images_and_landmark_coords(med_samples)

In [None]:
plot_images_and_landmark_coords(bad_samples)

***