# Model refresh validation

Building on `03-app`, this notebook validates a refreshed model against the QUT dataset.

**December 2023:**
* Generated an M1-only dataset with `poetry run ichthywhat create-rls-species-dataset-from-api` on 2023-12-14 and uploaded it to Kaggle.
* Did two training runs on Kaggle with `06a-model-refresh-training-kaggle.ipynb`:
  * `app-v2-20231219-200-epochs.pkl`: Exactly the same settings as before (v2 with 200 epochs). It may have underfit because there's more data now: the training loss was 0.573659 rather than 0.407830.
  * `app-v2-20231219-400-epochs.pkl`: Same v2 settings but with 400 epochs. The training loss is down to 0.440023, and the QUT accuracy is slightly higher than with 200 epochs. However, its QUT accuracy is marginally lower than that of the 2022-01 model (accuracy@10: 0.281 versus 0.284). Given that there are about 10% more species now (2399 versus 2167 before), this slight drop in accuracy is an acceptable trade-off.
* As an additional feasibility check of adding M2 species, did a training run with M1 + M2 species (v2 with 400 epochs and 20 frozen epochs). Training loss was ... QUT accuracy ...
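The larger species count also shifts the chance baseline: a uniform random ranking over N species puts the true species in the top k with probability k/N. A minimal sketch of that arithmetic, using the counts from the notes above (the helper name `chance_top_k_accuracy` is illustrative and not part of `ichthywhat`):

```python
def chance_top_k_accuracy(k: int, n_species: int) -> float:
    """Probability that a uniformly random ranking puts the true species in the top k."""
    if not 0 < k <= n_species:
        raise ValueError("k must be between 1 and n_species")
    return k / n_species


# Species counts taken from the December 2023 notes.
for label, n_species in [("2022-01 model", 2167), ("2023-12 refresh", 2399)]:
    print(f"{label}: chance accuracy@10 = {chance_top_k_accuracy(10, n_species):.5f}")
```

The larger label space lowers this baseline by roughly 10%; it is only a reference point and says nothing direct about how the trained models behave.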

**TODO**: fill in the blanks for M1 + M2

**TODO**: decide on deployment (M1 refreshed or M1 + M2?)

In [1]:
from fastai.learner import load_learner
import httpx

from ichthywhat import experiments
from ichthywhat.inference import OnnxWrapper
from ichthywhat.constants import ROOT_PATH, DEFAULT_DATA_PATH, DEFAULT_MODELS_PATH
from ichthywhat.training import export_learner_to_onnx, train_app_model

In [6]:
def load_qut_dataset(
    dataset_path=DEFAULT_DATA_PATH / "qut-cropped-controlled",
    species_json_url="https://raw.githubusercontent.com/yanirs/rls-data/master/output/species.json",
):
    """Load QUT image paths and labels, updating superseded names (assumes recent training data)."""
    response = httpx.get(species_json_url)
    response.raise_for_status()  # fail loudly instead of parsing an error page as JSON
    all_species = response.json()
    # Map each superseded scientific name to its current name.
    superseded_to_current = {}
    for species in all_species:
        for superseded_name in species.get("superseded_names", []):
            superseded_to_current[superseded_name] = species["scientific_name"]
    paths = []
    labels = []
    for path in dataset_path.glob("*.png"):
        paths.append(path)
        # File names look like "genus-species-N.png"; keep "Genus species".
        label = " ".join(path.name.split("-")[:2]).capitalize()
        labels.append(superseded_to_current.get(label, label))
    return paths, labels


qut_paths, qut_labels = load_qut_dataset()
list(zip(qut_paths, qut_labels))[:5]

[(Path('/vagrant/data/qut-cropped-controlled/anampses-caeruleopunctatus-13.png'),
  'Anampses caeruleopunctatus'),
 (Path('/vagrant/data/qut-cropped-controlled/thalassoma-trilobatum-4.png'),
  'Thalassoma trilobatum'),
 (Path('/vagrant/data/qut-cropped-controlled/plotosus-lineatus-7.png'),
  'Plotosus lineatus'),
 (Path('/vagrant/data/qut-cropped-controlled/cirrhilabrus-scottorum-14.png'),
  'Cirrhilabrus scottorum'),
 (Path('/vagrant/data/qut-cropped-controlled/lutjanus-quinquelineatus-7.png'),
  'Lutjanus quinquelineatus')]
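The two steps inside `load_qut_dataset` (deriving a label from the file name, then remapping superseded names) can be pulled out into pure functions and checked without network access. A sketch under that assumption; the helper names and the species records below are hypothetical placeholders, not real RLS data:

```python
from pathlib import Path


def label_from_path(path: Path) -> str:
    """Turn e.g. 'plotosus-lineatus-7.png' into 'Plotosus lineatus' (genus + species)."""
    return " ".join(path.name.split("-")[:2]).capitalize()


def build_superseded_map(all_species: list) -> dict:
    """Map each superseded name to its current scientific name."""
    return {
        superseded: species["scientific_name"]
        for species in all_species
        for superseded in species.get("superseded_names", [])
    }


# Hypothetical records shaped like the species.json entries read above.
records = [{"scientific_name": "Genus newname", "superseded_names": ["Genus oldname"]}]
mapping = build_superseded_map(records)
label = label_from_path(Path("genus-oldname-3.png"))
print(mapping.get(label, label))  # Genus newname
```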

In [3]:
def test_model(learner_pkl_path):
    """Evaluate a fastai learner and its ONNX export on the QUT images."""
    learner = load_learner(learner_pkl_path)
    learner_stats = experiments.test_learner(
        learner, qut_paths, qut_labels, show_grid=False
    )
    # Export next to the .pkl file, swapping the extension: model.pkl -> model.onnx
    onnx_path = learner_pkl_path.with_suffix(".onnx")
    export_learner_to_onnx(learner_pkl_path, onnx_path)
    onnx_stats = OnnxWrapper(onnx_path).evaluate(qut_paths, qut_labels)
    return dict(learner=learner_stats, onnx=onnx_stats)
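Both `experiments.test_learner` and `OnnxWrapper.evaluate` report top-1, top-3, and top-10 accuracy. Their implementations aren't shown here, so the following is only a generic sketch of the metric: a sample counts as a hit if its true class index is among the k highest scores. The function name and the list-of-lists input format are assumptions for illustration.

```python
def top_k_accuracy(scores: list, targets: list, k: int) -> float:
    """Fraction of rows whose true class index is among the k highest scores."""
    if len(scores) != len(targets):
        raise ValueError("scores and targets must have the same length")
    hits = 0
    for row, target in zip(scores, targets):
        # Indices of the k largest scores; sorting is stable, so ties favour lower indices.
        top_k = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += target in top_k
    return hits / len(targets)


# Three samples, three classes: only the first is right at k=1.
scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
targets = [1, 1, 0]
for k in (1, 2, 3):
    print(k, top_k_accuracy(scores, targets, k))
```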

In [4]:
test_model(DEFAULT_MODELS_PATH / "app-v2-20231219-200-epochs.pkl")

verbose: False, log level: Level.ERROR



{'learner': {'top_1_accuracy': 0.09460888057947159,
  'top_3_accuracy': 0.15010571479797363,
  'top_10_accuracy': 0.23678646981716156},
 'onnx': {'top_1_accuracy': 0.10940803382663848,
  'top_3_accuracy': 0.18128964059196617,
  'top_10_accuracy': 0.2774841437632135}}

In [5]:
test_model(DEFAULT_MODELS_PATH / "app-v2-20231219-400-epochs.pkl")

verbose: False, log level: Level.ERROR



{'learner': {'top_1_accuracy': 0.08615221828222275,
  'top_3_accuracy': 0.1532769501209259,
  'top_10_accuracy': 0.2452431321144104},
 'onnx': {'top_1_accuracy': 0.11469344608879492,
  'top_3_accuracy': 0.18551797040169132,
  'top_10_accuracy': 0.28118393234672306}}
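In both runs the ONNX export scores higher than the fastai learner on every top-k metric. A small sketch that tabulates the gap, with the numbers copied from the two outputs above (the helper name `onnx_minus_learner` is hypothetical):

```python
# Stats copied verbatim from the two test_model outputs.
runs = {
    "200 epochs": {
        "learner": {"top_1_accuracy": 0.09460888057947159,
                    "top_3_accuracy": 0.15010571479797363,
                    "top_10_accuracy": 0.23678646981716156},
        "onnx": {"top_1_accuracy": 0.10940803382663848,
                 "top_3_accuracy": 0.18128964059196617,
                 "top_10_accuracy": 0.2774841437632135},
    },
    "400 epochs": {
        "learner": {"top_1_accuracy": 0.08615221828222275,
                    "top_3_accuracy": 0.1532769501209259,
                    "top_10_accuracy": 0.2452431321144104},
        "onnx": {"top_1_accuracy": 0.11469344608879492,
                 "top_3_accuracy": 0.18551797040169132,
                 "top_10_accuracy": 0.28118393234672306},
    },
}


def onnx_minus_learner(stats: dict) -> dict:
    """Per-metric difference between the ONNX and fastai learner accuracies."""
    return {metric: stats["onnx"][metric] - stats["learner"][metric] for metric in stats["learner"]}


for name, stats in runs.items():
    gaps = onnx_minus_learner(stats)
    print(name, {metric: round(gap, 4) for metric, gap in gaps.items()})
```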