In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

import wandb
from easydict import EasyDict
from cvmt.ml.trainer import create_dataloader, SingletaskTrainLandmarks, mean_radial_error
from cvmt.utils import (load_yaml_params, nested_dict_to_easydict)

from cvmt.ml.models import MultiTaskLandmarkUNetCustom

import torch


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Work from two directories above the notebook (presumably the repository root,
# where `configs/` lives) so relative paths such as the config file and the W&B
# `./artifacts` download directory resolve.
# The notebook's own directory is remembered on the first run so that re-running
# this cell is a no-op instead of climbing two more levels each time.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, "..", ".."))

In [4]:
# Project parameters file, relative to the working directory set by the `os.chdir` cell above.
CONFIG_PARAMS_PATH: str = "configs/params.yaml"

In [5]:
# Parse the YAML config and convert it so nested keys are reachable as attributes
# (e.g. `params.TRAIN.SINGLE_TASK`).
raw_params = load_yaml_params(CONFIG_PARAMS_PATH)
params: EasyDict = nested_dict_to_easydict(raw_params)

In [6]:
# Pull the single-task training settings out of the parsed config.
task_config = params.TRAIN.SINGLE_TASK
task_id, batch_size, shuffle = (
    task_config.TASK_ID,
    task_config.BATCH_SIZE,
    task_config.SHUFFLE,
)

In [7]:
# Validation dataloader for the configured single task.
# Shuffling is deliberately off. The config's SHUFFLE flag lives under
# TRAIN.SINGLE_TASK, i.e. it is a training setting. A fixed order makes the
# per-batch errors printed by the pretrained and from-scratch evaluation cells refer
# to the same batches, and makes the "first batch" inspected below reproducible
# (assuming the val split applies no random augmentation — TODO confirm).
val_dataloader = create_dataloader(
    task_id=task_id,
    batch_size=batch_size,
    split='val',
    shuffle=False,
    params=params,
)

In [8]:
# W&B entity and project the training runs were logged to (combined below as
# the `entity/project` path for the public API).
project = params["WANDB"]["INIT"]["project"]
user = "sm-data-science"

In [9]:
# load the best model
api = wandb.Api()
path = f'{user}/{project}'
runs = api.runs(path=path)
latest_run = runs[0]

In [10]:
# Show the id of the first run returned above (the latest one); the checkpoint
# reference below is pinned to this id.
latest_run.id

'6gvcg0xa'

In [11]:
# Pin the checkpoint explicitly (run id + artifact version) so that re-running the
# notebook later evaluates the same weights even after newer runs are logged.
# The run id is the `latest_run.id` shown above at the time of writing. Entity and
# project are reused from the cell above rather than duplicated as a string literal.
CHECKPOINT_RUN_ID = "6gvcg0xa"
CHECKPOINT_VERSION = "v2"
checkpoint_reference = f"{user}/{project}/model-{CHECKPOINT_RUN_ID}:{CHECKPOINT_VERSION}"

In [12]:
# Resolve the pinned model artifact through the W&B public API.
artifact = api.artifact(name=checkpoint_reference)

In [13]:
# Download the artifact's files locally; `download()` returns the local directory path.
artifact_dir = artifact.download()

[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [14]:
# Local directory the artifact was downloaded into.
artifact_dir

'./artifacts/model-6gvcg0xa:v2'

In [15]:
# Lightning checkpoint file inside the downloaded model artifact.
# `os.path.join` avoids doubled or missing separators when building the path.
checkpoint_path = os.path.join(artifact_dir, "model.ckpt")
print(checkpoint_path)

./artifacts/model-6gvcg0xa:v2/model.ckpt


### Use a PyTorch Lightning module as the wrapper for the model class

See the issue below for why a PyTorch Lightning model cannot be loaded from a checkpoint when its hyperparameters are saved,
and how to use a PyTorch Lightning module to wrap a model that is defined outside the Lightning module.

https://github.com/Lightning-AI/lightning/issues/3629#issue-707536217

In [16]:
use_pretrain = True
# Head of the multi-task model that is scored against the 'v_landmarks' targets.
# It gets its own name so the config-derived `task_id` (used by `create_dataloader`)
# is not clobbered when this cell runs.
# NOTE(review): hard-coded to 3 as before; confirm that 3 is the v_landmarks head.
EVAL_TASK_ID = 3

model_params = params.MODEL.PARAMS
model = MultiTaskLandmarkUNetCustom(**model_params)

if use_pretrain:
    # `load_from_checkpoint` is a classmethod that returns a new module with the
    # checkpoint weights loaded, so call it on the class. Passing `model=model`
    # supplies the architecture explicitly. `map_location="cpu"` matches the loop
    # below, which never moves tensors to a GPU.
    pl_model = SingletaskTrainLandmarks.load_from_checkpoint(
        checkpoint_path, map_location="cpu", model=model,
    )
    model = pl_model.model
else:
    pl_model = SingletaskTrainLandmarks(model=model)

# Evaluation mode, with float64 weights to match the float64 input images.
model.eval()
model.double()

for batch in val_dataloader:
    # `Tensor.double()` is not in-place, so the result must be rebound.
    images = batch['image'].double()
    labels = batch['v_landmarks']
    with torch.no_grad():
        predictions = model(images, task_id=EVAL_TASK_ID)
    # Mean radial error for this batch.
    print(mean_radial_error(preds=predictions, targets=labels))

  rank_zero_warn(


357.91337534914555
382.03686157416655
385.4165009602833
382.36701108210025
378.38431261992946
372.00592946383875


In [17]:
use_pretrain = False
# Head of the multi-task model that is scored against the 'v_landmarks' targets.
# It gets its own name so the config-derived `task_id` (used by `create_dataloader`)
# is not clobbered when this cell runs.
# NOTE(review): hard-coded to 3 as before; confirm that 3 is the v_landmarks head.
EVAL_TASK_ID = 3

model_params = params.MODEL.PARAMS
model = MultiTaskLandmarkUNetCustom(**model_params)

if use_pretrain:
    # `load_from_checkpoint` is a classmethod that returns a new module with the
    # checkpoint weights loaded, so call it on the class. Passing `model=model`
    # supplies the architecture explicitly. `map_location="cpu"` matches the loop
    # below, which never moves tensors to a GPU.
    pl_model = SingletaskTrainLandmarks.load_from_checkpoint(
        checkpoint_path, map_location="cpu", model=model,
    )
    model = pl_model.model
else:
    pl_model = SingletaskTrainLandmarks(model=model)

# Evaluation mode, with float64 weights to match the float64 input images.
model.eval()
model.double()

for batch in val_dataloader:
    # `Tensor.double()` is not in-place, so the result must be rebound.
    images = batch['image'].double()
    labels = batch['v_landmarks']
    with torch.no_grad():
        predictions = model(images, task_id=EVAL_TASK_ID)
    # Mean radial error for this batch.
    print(mean_radial_error(preds=predictions, targets=labels))

428.2069990576174
418.47022914944444
398.7372133209992
412.6866796642769
410.0188256831929
364.20077849558703


## Check the validation data

In [18]:
# Grab a single validation batch for inspection. `next(iter(...))` fails loudly
# on an empty dataloader, instead of silently reusing the `images`/`labels` left
# over from the evaluation cells above.
batch = next(iter(val_dataloader))
images, labels = batch['image'], batch['v_landmarks']

In [19]:
# First image of the batch; the leading channel axis is kept (see the shape below).
sample = images[0]

In [20]:
# Dimensions of the single sample.
sample.size()

torch.Size([1, 256, 256])

In [21]:
# Mean pixel intensity of the sample (checks input normalization).
torch.mean(sample)

tensor(83.6887, dtype=torch.float64)

In [22]:
# Standard deviation of the sample's pixel intensities.
torch.std(sample)

tensor(40.9124, dtype=torch.float64)