In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import yaml
import tqdm
import json
import torch
import hydra
import pprint
import logging
import colorlog
import src.prepare  # noqa
import logging.config

import numpy as np
import pytorch_lightning as pl

from src.config import read_config
from hydra.utils import instantiate
from src.load import load_model_from_cfg
from omegaconf import OmegaConf, DictConfig
from hydra import initialize, initialize_config_module, initialize_config_dir, compose

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
LOGGING_CONFIG = {
    'version': 1,
    'disable_existing_loggers': False,
    'formatters': {
        'simple': {
            'format': '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s',
            'datefmt': '%d/%m/%y %H:%M:%S',
        },
        'colorlog': {
            '()': 'colorlog.ColoredFormatter',
            'format': '[%(white)s%(asctime)s%(reset)s] %(log_color)s%(levelname)s%(reset)s   %(message)s',
            'datefmt': '%d/%m/%y %H:%M:%S',
            'log_colors': {
                'DEBUG': 'purple',
                'INFO': 'blue',
                'WARNING': 'yellow',
                'ERROR': 'red',
                'CRITICAL': 'red',
            },
        },
    },
}

logging.config.dictConfig(LOGGING_CONFIG)

logger = logging.getLogger(__name__)

logger.info("logger has been configured.")

<div class="alert alert-info">

**source:** [hydra configurations in a notebook](https://github.com/facebookresearch/hydra/blob/main/examples/jupyter_notebooks/compose_configs_in_notebook.ipynb).

</div>

In [10]:
RUN_DIRS = [
    "/home/nadir/disk/codes/tmr-code/outputs/classifier_babel-classifier_guoh3dfeats_16",
    "/home/nadir/disk/codes/tmr-code/outputs/classifier_babel-classifier_guoh3dfeats_32",
    "/home/nadir/disk/codes/tmr-code/outputs/classifier_babel-classifier_guoh3dfeats_64",
    "/home/nadir/disk/codes/tmr-code/outputs/classifier_babel-classifier_guoh3dfeats_128",
    "/home/nadir/disk/codes/tmr-code/outputs/classifier_babel-classifier_guoh3dfeats_256",
    "/home/nadir/disk/codes/tmr-code/outputs/classifier_babel-classifier_guoh3dfeats_512",
]

RUN_DIR = "/home/nadir/disk/codes/tmr-code/outputs/classifier_babel-classifier_guoh3dfeats"

In [5]:
with initialize(
    version_base=None,
    config_path="configs",
    # 
):
    # config = compose(overrides=["+db=mysql"])
    config = compose(
        config_name="evaluate-classifier",
        return_hydra_config=True,
        overrides=[f"run_dir={RUN_DIR}"]
    )
    
    from hydra.core.hydra_config import HydraConfig
    
    HydraConfig.instance().set_config(config)
    
print(OmegaConf.to_yaml(config, sort_keys=False))

hydra:
  run:
    dir: ${run_dir}
  sweep:
    dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
    subdir: ${hydra.job.num}
  launcher:
    _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
  sweeper:
    _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
    max_batch_size: null
    params: null
  help:
    app_name: ${hydra.job.name}
    header: '${hydra.help.app_name} is powered by Hydra.

      '
    footer: 'Powered by Hydra (https://hydra.cc)

      Use --hydra-help to view Hydra specific help

      '
    template: '${hydra.help.header}

      == Configuration groups ==

      Compose your configuration from those groups (group=option)


      $APP_CONFIG_GROUPS


      == Config ==

      Override anything in the config (foo.bar=value)


      $CONFIG


      ${hydra.help.footer}

      '
  hydra_help:
    template: 'Hydra (${hydra.runtime.version})

      See https://hydra.cc for more info.


      == Flags ==

      $FLAGS_HELP


      == Config

In [19]:
for run_dir in RUN_DIRS:
    try:
        logger.info(f"[run_dir]: {run_dir}")
        print(f"[model-variant]: {run_dir.split('_')[-1]}")
        
        device = config.device
        # run_dir = config.run_dir
        examples = config.examples
        ckpt_name = config.ckpt

        # save_dir = os.path.join(run_dir, "segmentation-evaluation")
        # os.makedirs(save_dir, exist_ok=True)

        # NOTE: moved up here in order to use the segmentation config for the dataset
        # defined in the config file rather than the config used to train the model
        dataset = instantiate(config.data, mode="classic", split="test")

        # NOTE: will load the config used to train the model
        cfg = read_config(run_dir)

        pl.seed_everything(cfg.seed)

        logger.info("[model]: loading")
        model = load_model_from_cfg(cfg, ckpt_name, eval_mode=True, device=device)

        logger.info(f"[dataset.mode]: {dataset.mode}")

        dataloader = instantiate(
            cfg.dataloader,
            dataset=dataset,
            collate_fn=dataset.collate_fn,
            shuffle=False,
        )

        logger.info(f"[#dataloader]: {len(dataset)}")
        logger.info(f"[dataloader.batch_size]: {dataloader.batch_size}")

        model = model.eval()
        
        all_predictions, all_labels = [], []
                
        from src.model import ClassifierModel

        with torch.no_grad():
            for index, batch in tqdm.tqdm(iterable=enumerate(dataloader), total=len(dataloader), desc="[evaluate-classification]"):
                batch["motion_x_dict"]["x"] = batch["motion_x_dict"]["x"].to(device)
                batch["motion_x_dict"]["mask"] = batch["motion_x_dict"]["mask"].to(device)
                
                outputs = model(batch, None)
                
                labels = ClassifierModel.get_targets(batch, None)
                predictions = (torch.sigmoid(outputs) > 0.5).long().cpu().numpy()
                
                for prediction, label in zip(predictions, labels):
                    all_predictions.append(prediction)
                    all_labels.append(label)

        from src.model.metrics import accuracy_score, precision_score, recall_score, f1_score
        from sklearn.metrics import classification_report

        # TODO: compute balanced accuracy score
        accuracy = accuracy_score(all_predictions, all_labels)
        precision = precision_score(all_labels, all_predictions, average='macro')
        recall = recall_score(all_labels, all_predictions, average='macro')
        f1 = f1_score(all_labels, all_predictions, average='macro')
        report = classification_report(all_labels, all_predictions, zero_division=0)

        print(f"[accuracy]: {(accuracy * 100):.02f}%")
        print(f"[precision]: {(precision * 100):.02f}%")
        print(f"[recall]: {(recall * 100):.02f}%")
        print(f"[f1]: {(f1 * 100):.02f}%")
        print(report)
        
        print("--- --- ---")
        
        from sklearn.metrics import confusion_matrix as generate_confusion_matrix

        confusion_matrix = generate_confusion_matrix(all_labels, all_predictions)
        
        print("--- --- ---")
        
        print(confusion_matrix)
        
        print("--- --- ---\n\n\n\n")
    except Exception as exception:
        print(f"[error]: {exception}")

Global seed set to 1234


[model-variant]: 16


[evaluate-classification]: 100%|██████████| 655/655 [01:19<00:00,  8.27it/s]
Global seed set to 1234


[accuracy]: 82.53%
[precision]: 79.59%
[recall]: 76.99%
[f1]: 78.07%
              precision    recall  f1-score   support

         0.0       0.74      0.63      0.68      1542
         1.0       0.86      0.90      0.88      3691

    accuracy                           0.83      5233
   macro avg       0.80      0.77      0.78      5233
weighted avg       0.82      0.83      0.82      5233

--- --- ---
--- --- ---
[[ 979  563]
 [ 351 3340]]
--- --- ---




[error]: Error(s) in loading state_dict for Sequential:
	size mismatch for 0.weight: copying a param with shape torch.Size([64, 256]) from checkpoint, the shape in current model is torch.Size([32, 256]).
	size mismatch for 0.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for 2.weight: copying a param with shape torch.Size([1, 64]) from checkpoint, the shape in current model is torch.Size([1, 32]).


Global seed set to 1234


[model-variant]: 64


[evaluate-classification]: 100%|██████████| 655/655 [01:19<00:00,  8.21it/s]


[accuracy]: 81.37%
[precision]: 77.66%
[recall]: 77.05%
[f1]: 77.34%
              precision    recall  f1-score   support

         0.0       0.69      0.67      0.68      1542
         1.0       0.86      0.88      0.87      3691

    accuracy                           0.81      5233
   macro avg       0.78      0.77      0.77      5233
weighted avg       0.81      0.81      0.81      5233

--- --- ---
--- --- ---
[[1026  516]
 [ 459 3232]]
--- --- ---






Global seed set to 1234


[model-variant]: 128


[evaluate-classification]: 100%|██████████| 655/655 [01:19<00:00,  8.25it/s]
Global seed set to 1234


[accuracy]: 81.10%
[precision]: 77.28%
[recall]: 77.18%
[f1]: 77.23%
              precision    recall  f1-score   support

         0.0       0.68      0.68      0.68      1542
         1.0       0.87      0.87      0.87      3691

    accuracy                           0.81      5233
   macro avg       0.77      0.77      0.77      5233
weighted avg       0.81      0.81      0.81      5233

--- --- ---
--- --- ---
[[1043  499]
 [ 490 3201]]
--- --- ---




[model-variant]: 256


[evaluate-classification]: 100%|██████████| 655/655 [01:21<00:00,  8.08it/s]


[accuracy]: 81.92%
[precision]: 78.56%
[recall]: 76.86%
[f1]: 77.61%
              precision    recall  f1-score   support

         0.0       0.71      0.65      0.68      1542
         1.0       0.86      0.89      0.87      3691

    accuracy                           0.82      5233
   macro avg       0.79      0.77      0.78      5233
weighted avg       0.82      0.82      0.82      5233

--- --- ---
--- --- ---
[[ 995  547]
 [ 399 3292]]
--- --- ---






Global seed set to 1234


[model-variant]: 512


[evaluate-classification]: 100%|██████████| 655/655 [01:21<00:00,  8.08it/s]


[accuracy]: 81.88%
[precision]: 78.49%
[recall]: 76.91%
[f1]: 77.61%
              precision    recall  f1-score   support

         0.0       0.71      0.65      0.68      1542
         1.0       0.86      0.89      0.87      3691

    accuracy                           0.82      5233
   macro avg       0.78      0.77      0.78      5233
weighted avg       0.81      0.82      0.82      5233

--- --- ---
--- --- ---
[[ 999  543]
 [ 405 3286]]
--- --- ---




