In [1]:
%cd ..
%load_ext autoreload
%autoreload 2

/nfs/homedirs/elfleins/Developer/baseline


In [2]:
from pathlib import Path

import torch
import yaml
from pytorch_lightning import Trainer

from uncertainty_est.models import DeepEnsemble
from uncertainty_est.models import load_checkpoint, load_model, resolve_model_checkpoint
from uncertainty_est.archs.arch_factory import get_arch
from uncertainty_est.data.dataloaders import get_dataloader, get_dataset

In [24]:
checkpoint_dir = Path("./thesis_logs/cifar10/CE Baseline/")
output_folder = Path("./thesis_logs/cifar10/Ensemble/version_0/")
output_folder.mkdir(exist_ok=True, parents=True)

version_dirs = checkpoint_dir.glob("**/version_*")
checkpoints = [resolve_model_checkpoint(d) for d in version_dirs]
checkpoints

[PosixPath('thesis_logs/cifar10/CE Baseline/version_3/epoch=62-step=88577.ckpt'),
 PosixPath('thesis_logs/cifar10/CE Baseline/version_0/epoch=48-step=68893.ckpt'),
 PosixPath('thesis_logs/cifar10/CE Baseline/version_2/epoch=47-step=67487.ckpt'),
 PosixPath('thesis_logs/cifar10/CE Baseline/version_4/epoch=60-step=85765.ckpt'),
 PosixPath('thesis_logs/cifar10/CE Baseline/version_1/epoch=33-step=47803.ckpt')]

In [25]:
ebm, config = load_checkpoint(checkpoints[0])

| Wide-Resnet 16x8


In [26]:
de = DeepEnsemble(**config["model_config"], num_models=len(checkpoints))

for i, ckpt in enumerate(checkpoints[1:], 1):
    backbone = get_arch(de.hparams.arch_name, de.hparams.arch_config)
    # Remove the first module name in order to load backbone model
    strip_sd = {
        k.split(".", 1)[1]: v for k, v in torch.load(ckpt)["state_dict"].items()
    }
    print(i)
    backbone.load_state_dict(strip_sd)
    de.models[i] = backbone
len(de.models)

| Wide-Resnet 16x8
| Wide-Resnet 16x8
1
| Wide-Resnet 16x8
2
| Wide-Resnet 16x8
3
| Wide-Resnet 16x8
4


5

In [27]:
id_dl = get_dataloader(config["dataset"], "test")
dl = get_dataloader("mnist", "test", data_shape=id_dl.dataset.data_shape)

Files already downloaded and verified


In [28]:
de.eval()
de.cuda()
gt, pred = de.get_gt_preds(id_dl)

(torch.argmax(pred, 1) == gt).float().mean()

100%|██████████| 313/313 [00:41<00:00,  7.51it/s]


tensor(0.9335)

In [29]:
de.eval_ood(id_dl, [("LSUN", dl)])

313it [00:41,  7.49it/s]
0it [00:00, ?it/s]

output with shape [1, 32, 32] doesn't match the broadcast shape [3, 32, 32]





{}

In [30]:
de.ood_val_datasets = None
t = Trainer(max_epochs=0, default_root_dir="temp")
t.fit(de, dl)

GPU available: True, used: False
TPU available: False, using: 0 TPU cores
Set SLURM handle signals.

  | Name     | Type       | Params
----------------------------------------
0 | backbone | WideResNet | 11.0 M
1 | models   | ModuleList | 54.8 M
----------------------------------------
65.8 M    Trainable params
0         Non-trainable params
65.8 M    Total params
263.008   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

1

In [31]:
t.save_checkpoint(output_folder / "last.ckpt")

In [32]:
config["model_config"] = dict(de.hparams)
config["model_name"] = "DeepEnsemble"

In [33]:
config

{'batch_size': 32,
 'checkpoint_config': {'mode': 'max', 'monitor': 'val/acc'},
 'data_shape': [32, 32, 3],
 'dataset': 'cifar10',
 'db_collection': 'cifar10',
 'earlystop_config': {'mode': 'max', 'monitor': 'val/acc', 'patience': 10},
 'log_dir': './thesis_logs',
 'model_config': {'arch_name': 'wrn',
  'arch_config': {'depth': 16,
   'input_channels': 3,
   'num_classes': 10,
   'widen_factor': 8},
  'learning_rate': 0.0001,
  'momentum': 0.9,
  'weight_decay': 0.0,
  'data_shape': [32, 32, 3],
  'ood_val_datasets': ['celeb-a', 'cifar100'],
  'num_models': 5},
 'model_name': 'DeepEnsemble',
 'num_classes': 10,
 'ood_dataset': None,
 'output_folder': 'cifar10/ce_baseline_16_8',
 'overwrite': 64,
 'seed': 2324234,
 'test_ood_datasets': ['lsun',
  'textures',
  'cifar100',
  'svhn',
  'celeb-a',
  'uniform_noise',
  'gaussian_noise',
  'constant',
  'svhn_unscaled'],
 'trainer_config': {'benchmark': True, 'gpus': 1, 'max_epochs': 100}}

In [34]:
with (output_folder / "config.yaml").open("w") as f:
    f.write(yaml.dump(config))

In [35]:
de, config = load_model(output_folder)
len(de.models)

| Wide-Resnet 16x8


5

In [36]:
de.eval()
de.cuda()
gt, pred = de.get_gt_preds(id_dl)

(torch.argmax(pred, 1) == gt).float().mean()

100%|██████████| 313/313 [00:41<00:00,  7.51it/s]


tensor(0.9335)

In [22]:
de.eval_ood(id_dl, [("LSUN", dl)])

313it [00:40,  7.74it/s]
313it [00:40,  7.74it/s]


{('LSUN', 'Variance', 'AUROC'): 54.333543500000005,
 ('LSUN', 'Variance', 'AUPR'): 46.981769481630884,
 ('LSUN', 'Variance', 'iAUROC'): 54.333543500000005,
 ('LSUN', 'Variance', 'iAUPR'): 66.78694661980643,
 ('LSUN', 'Entropy', 'AUROC'): 75.048826,
 ('LSUN', 'Entropy', 'AUPR'): 67.05583850562287,
 ('LSUN', 'Entropy', 'iAUROC'): 75.04882599999999,
 ('LSUN', 'Entropy', 'iAUPR'): 73.4901972956034}

In [9]:
de.ood_detect(dl)

100%|██████████| 313/313 [00:08<00:00, 38.17it/s]


{'Variance': array([ 9.197707 , 15.414769 ,  8.869913 , ..., 13.5622425, 11.11999  ,
        15.780851 ], dtype=float32)}

In [23]:
de.load_state_dict(torch.load(output_folder / "last.ckpt")["state_dict"])

<All keys matched successfully>