In [1]:
# Reload edited modules (e.g. the cvmt package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

import wandb
from easydict import EasyDict
from cvmt.utils import (load_yaml_params, nested_dict_to_easydict)

from cvmt.inference.inference import (load_pretrained_model_eval_mode, predict_image,
                                      img_coord_2_cartesian_coord, translate_landmarks,
                                      rotate_landmarks, plot_landmarks,
                                      classify_by_mcnamara_and_franchi,
                                      plot_image_and_vertebral_landmarks,
                                     )
from cvmt.ml.utils import download_wandb_model_checkpoint

import torch
from typing import *
import numpy as np
import pandas as pd
import random

from PIL import Image

import glob


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Move from the notebook's directory to the repository root so that relative
# paths such as "configs/params.yaml" resolve. Guarded so that re-running this
# cell does not keep climbing the directory tree.
if not os.path.isfile("configs/params.yaml"):
    os.chdir("../../")

In [4]:
def load_env_file(path: str) -> Dict[str, str]:
    """Export the ``KEY=VALUE`` pairs of a dotenv-style file into ``os.environ``.

    Blank lines and ``#`` comment lines are skipped, a leading ``export`` is
    accepted, and one pair of matching surrounding quotes is removed from values.
    Lines without ``=`` (or with an empty key) are ignored. Existing variables
    are overwritten, as ``source`` would do.

    Args:
        path: Path to the env file.

    Returns:
        The variables that were set, in file order.
    """
    loaded: Dict[str, str] = {}
    with open(path, encoding="utf-8") as env_file:
        for raw_line in env_file:
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            if line.startswith("export "):
                line = line[len("export "):].lstrip()
            key, sep, value = line.partition("=")
            key, value = key.strip(), value.strip()
            if not sep or not key:
                continue
            # Strip one pair of matching quotes: KEY="value" -> value
            if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
                value = value[1:-1]
            os.environ[key] = value
            loaded[key] = value
    return loaded


# `!source configs/.env` ran in a throw-away subshell (where `source` may not even
# exist), so none of its variables reached this kernel. Load them in-process.
# The result is not displayed, to avoid leaking secrets into the notebook output.
ENV_FILE_PATH = "configs/.env"
if os.path.isfile(ENV_FILE_PATH):
    load_env_file(ENV_FILE_PATH)
else:
    print(f"No env file at {ENV_FILE_PATH!r}; using the existing environment.")

## Read params

In [5]:
# YAML file holding all training / inference parameters (relative to the repo root).
CONFIG_PARAMS_PATH: str = "configs/params.yaml"

In [6]:
# Load the YAML configuration and convert it to an EasyDict so nested
# sections can be accessed as attributes (e.g. params.TRAIN.LOSS_NAME).
raw_params = load_yaml_params(CONFIG_PARAMS_PATH)
params: EasyDict = nested_dict_to_easydict(raw_params)

## Download the checkpoint of the model

In [7]:
# Fetch the checkpoint referenced in the config from Weights & Biases; the
# helper's result is unpacked into the local checkpoint path and a model id.
wandb_reference = params.VERIFY.WANDB_CHECKPOINT_REFERENCE_NAME
checkpoint_path, model_id = download_wandb_model_checkpoint(
    wandb_checkpoint_uri=wandb_reference,
)
print(checkpoint_path)

[34m[1mwandb[0m: Downloading large artifact model-lmbw0bqa:v69, 100.36MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.4


./artifacts/model-lmbw0bqa:v69/model.ckpt


## Load Model

In [8]:
# Model / task settings pulled from the config.
use_pretrain = True

train_params = params.TRAIN
task_config = train_params.V_LANDMARK_TASK
task_id = task_config.TASK_ID
loss_name = train_params.LOSS_NAME

model_params = params.MODEL.PARAMS
transforms_params = params.INFERENCE.TRANSFORMS

# Thresholds passed to the McNamara & Franchi stage classifier.
mcnamara_args = params.INFERENCE.MCNAMARA.ARGS
(
    concavity_thresh,
    ant_pos_thresh,
    sup_inf_thresh,
    rect_thresh_min,
    rect_thresh_max,
) = (
    mcnamara_args.concavity_thresh,
    mcnamara_args.ant_pos_thresh,
    mcnamara_args.sup_inf_thresh,
    mcnamara_args.rect_thresh_min,
    mcnamara_args.rect_thresh_max,
)

In [9]:
# Display the McNamara & Franchi thresholds loaded from the config.
mcnamara_args

{'concavity_thresh': 1.0,
 'ant_pos_thresh': 0.95,
 'sup_inf_thresh': 0.95,
 'rect_thresh_min': 0.95,
 'rect_thresh_max': 1.05}

In [10]:
# Build the model from the config, load the downloaded checkpoint into it, and
# get back the model together with the device it was placed on.
model, device = load_pretrained_model_eval_mode(
    checkpoint_path=checkpoint_path,
    model_params=model_params,
    task_id=task_id,
    loss_name=loss_name,
    use_pretrain=use_pretrain,
)

  rank_zero_warn(


## Load the pixel-to-cm ratio table

In [11]:
# Headerless, ';'-delimited file with a single 'ratio' column. The table is
# assumed to be in image order, i.e. row i (1-based) belongs to "<i>.jpg".
pixel_to_cm_ratio_table = pd.read_csv(
    "data/test_dataset/pixel_to_cm_ratios.csv",
    names=['ratio'],
    header=None,
    delimiter=";",
)

n_ratio_rows = len(pixel_to_cm_ratio_table)
pixel_to_cm_ratio_table['filename'] = [f"{i}.jpg" for i in range(1, n_ratio_rows + 1)]

## Predict landmarks for each test image and classify its maturation stage

In [12]:
img_dir = "data/test_dataset"
stage_records = []

for row in pixel_to_cm_ratio_table.to_dict('records'):
    # read the pixel to cm ratio and the filename
    pixel_to_cm_ratio = row['ratio']
    filename = row['filename']
    # Open the image in a context manager: PIL opens files lazily and otherwise
    # keeps the handle open, leaking one file descriptor per image.
    with Image.open(os.path.join(img_dir, filename)) as image:
        # get the landmarks using the model
        rescaled_landmarks_coords = predict_image(
            image=image,
            model=model,
            task_id=task_id,
            transforms_params=transforms_params,
            device=device,
        )
    # get the bone age maturity class
    stage = classify_by_mcnamara_and_franchi(
        rescaled_landmarks_coords,
        pixel_to_cm_ratio,
        concavity_thresh,
        ant_pos_thresh,
        sup_inf_thresh,
        rect_thresh_min,
        rect_thresh_max,
    )
    # collect one record per image; the DataFrame is built once after the loop
    stage_records.append(
        {
            'filename': filename,
            'stage': stage,
        }
    )


stages = pd.DataFrame(stage_records)

In [13]:
# Predicted stage for each image.
stages

Unnamed: 0,filename,stage
0,1.jpg,cs4
1,2.jpg,undefined
2,3.jpg,cs5
3,4.jpg,cs3
4,5.jpg,cs6
5,6.jpg,cs5
6,7.jpg,cs1
7,8.jpg,cs4
8,9.jpg,cs4
9,10.jpg,cs4


In [14]:
# Number of images per predicted stage.
stages['stage'].value_counts()

cs4          12
cs3          11
cs1          11
cs5           7
cs6           6
cs2           4
undefined     1
Name: stage, dtype: int64

## Write stages to disk

In [15]:
# Persist one row per image (filename, stage). index=False keeps the RangeIndex
# out of the file, so re-reading it does not yield an extra "Unnamed: 0" column.
STAGES_OUTPUT_PATH = "artifacts/stages.csv"
os.makedirs(os.path.dirname(STAGES_OUTPUT_PATH), exist_ok=True)
stages.to_csv(STAGES_OUTPUT_PATH, sep=';', index=False)