# D3M Tutorial

In this tutorial, we will use D3M to monitor deteriorating shifts in the UCI Heart Disease dataset. The preprocessed dataset is available in ``data/uci_data/``.

In [1]:
import os
os.chdir('../')
import torch
import numpy as np

SEED = 57
np.random.seed(SEED)
torch.random.manual_seed(SEED)


### Import data

We begin by importing the data:

In [2]:
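from data.uci_data import UCIDataset

data_dict = torch.load('data/uci_data/uci_heart_torch.pt')
uci_dict = {}
for k, data in data_dict.items():
    # each split is a list of (features, label) pairs: unzip it into features and labels
    data = list(zip(*data))
    X, y = torch.stack(data[0]), torch.tensor(data[1], dtype=torch.int)

    # min-max normalize each feature to [0, 1] using this split's own min and max;
    # note that a constant feature (max_ == min_) would yield NaNs here
    min_ = torch.min(X, dim=0).values
    max_ = torch.max(X, dim=0).values
    X = (X - min_) / (max_ - min_)
    uci_dict[k] = UCIDataset(X, y)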
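The normalization step above rescales every feature column to [0, 1] using the minimum and maximum of the split being processed. The following minimal, dependency-free sketch reproduces that rescaling on plain Python lists so the arithmetic is easy to check. The helper name ``min_max_normalize`` and its optional ``ref_rows`` argument are illustrative additions, not part of the tutorial's code: ``ref_rows`` lets you take the statistics from another split, and the guard for constant features avoids the division by zero that the cell above would hit.

```python
def min_max_normalize(rows, ref_rows=None):
    """Rescale every feature column of `rows` to [0, 1].

    The per-feature min and max come from `ref_rows` (default: `rows` itself,
    which mirrors the tutorial cell, where each split uses its own statistics).
    """
    ref = rows if ref_rows is None else ref_rows
    n_features = len(ref[0])
    mins = [min(r[j] for r in ref) for j in range(n_features)]
    maxs = [max(r[j] for r in ref) for j in range(n_features)]
    normalized = []
    for r in rows:
        normalized.append([
            # Guard added in this sketch: a constant feature would otherwise
            # divide by zero, so it is mapped to 0.0 instead.
            (r[j] - mins[j]) / (maxs[j] - mins[j]) if maxs[j] > mins[j] else 0.0
            for j in range(n_features)
        ])
    return normalized


rows = [[1.0, 10.0], [3.0, 10.0], [2.0, 10.0]]          # second feature is constant
print(min_max_normalize(rows))                           # [[0.0, 0.0], [1.0, 0.0], [0.5, 0.0]]
print(min_max_normalize([[5.0, 10.0]], ref_rows=rows))   # [[2.0, 0.0]]
```

With a split's own statistics, every value lands in [0, 1]. With another split's statistics, values outside that range show that a feature's range has shifted, which is information that per-split normalization removes.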

We use ``train`` to train the base model and ``val`` (stored as ``valid``) to validate it. The same ``val`` split, stored as ``d3m_train``, is then used to fit Phi, the distribution of i.i.d. disagreement rates.

We then run the monitor on ``d3m_train`` and on ``d3m_id`` (the ``iid_test`` split) to check that it is well calibrated, i.e. that it rarely flags in-distribution data, including samples it has never seen.

Finally, we run the monitor on ``d3m_ood`` (the ``ood_test`` split) to check that it detects deteriorating shifts in the data.

In [3]:
dataset_dict = {}
dataset_dict['train'] = uci_dict['train']
dataset_dict['valid'] = uci_dict['val']
dataset_dict['d3m_train'] = uci_dict['val']
dataset_dict['d3m_id'] = uci_dict['iid_test']
dataset_dict['d3m_ood'] = uci_dict['ood_test']

### Import D3M components

D3M involves two primary components: 
    
- the base model
- the monitor

The base model depends on the data modality: for tabular data we use ``MLPModel``, and for images we use ``ConvModel``.
The monitor takes a base model, the training and validation datasets, and a training configuration. We load the configuration with ``hydra-core``.

In [4]:
import hydra
from omegaconf import DictConfig
from d3m import MLPModel, D3MBayesianMonitor
from experiments.utils import get_configs

In [5]:
hydra.initialize(config_path='../experiments/configs', version_base='1.2')
args = hydra.compose(config_name="uci_best")

In [6]:
model_config, train_config = get_configs(args)

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
base_model = MLPModel(model_config, train_size=len(dataset_dict['train']))
monitor = D3MBayesianMonitor(
    model=base_model,
    trainset=dataset_dict['train'],
    valset=dataset_dict['valid'],
    train_cfg=train_config,
    device=device,
)

### Base Classifier Training

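We can now train the base model with these settings by calling ``monitor.train_model``. We also pass a loader over the OOD test split so that loss and accuracy on the shifted data are reported after every epoch.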

In [8]:
ood_testloader = torch.utils.data.DataLoader(
    dataset_dict['d3m_ood'],
    batch_size=train_config.batch_size,
    shuffle=False,
    num_workers=train_config.num_workers,
    pin_memory=train_config.pin_memory,
)
output_metrics = monitor.train_model(tqdm_enabled=True, testloader=ood_testloader)

Epoch:  0, train loss: 19.4199	train accuracy: 0.5706
Epoch:  0, val loss: 0.6563	val accuracy: 0.6317
Epoch:  0, ood test loss: 0.7521	ood test accuracy: 0.3264
Epoch:  1, train loss: 19.2373	train accuracy: 0.5537
Epoch:  1, val loss: 0.6396	val accuracy: 0.6295
Epoch:  1, ood test loss: 0.7512	ood test accuracy: 0.3203
Epoch:  2, train loss: 19.0393	train accuracy: 0.6201
Epoch:  2, val loss: 0.6254	val accuracy: 0.6808
Epoch:  2, ood test loss: 0.7352	ood test accuracy: 0.3776
Epoch:  3, train loss: 18.8860	train accuracy: 0.6297
Epoch:  3, val loss: 0.6071	val accuracy: 0.7355
Epoch:  3, ood test loss: 0.7164	ood test accuracy: 0.4826
Epoch:  4, train loss: 18.7013	train accuracy: 0.6401
Epoch:  4, val loss: 0.5977	val accuracy: 0.7388
Epoch:  4, ood test loss: 0.6957	ood test accuracy: 0.5903
Epoch:  5, train loss: 18.5664	train accuracy: 0.6033
Epoch:  5, val loss: 0.5797	val accuracy: 0.7388
Epoch:  5, ood test loss: 0.6989	ood test accuracy: 0.5582
Epoch:  6, train loss: 18.3861	train accuracy: 0.6939
Epoch:  6, val loss: 0.5632	val accuracy: 0.7734
Epoch:  6, ood test loss: 0.6770	ood test accuracy: 0.5660
Epoch:  7, train loss: 18.2271	train accuracy: 0.7279
Epoch:  7, val loss: 0.5674	val accuracy: 0.7645
Epoch:  7, ood test loss: 0.6913	ood test accuracy: 0.6372
Epoch:  8, train loss: 18.0974	train accuracy: 0.6931
Epoch:  8, val loss: 0.5618	val accuracy: 0.7746
Epoch:  8, ood test loss: 0.6744	ood test accuracy: 0.5764
Epoch:  9, train loss: 17.9380	train accuracy: 0.7270
Epoch:  9, val loss: 0.5466	val accuracy: 0.7589
Epoch:  9, ood test loss: 0.6630	ood test accuracy: 0.6424
Epoch: 10, train loss: 17.7868	train accuracy: 0.7452
Epoch: 10, val loss: 0.5330	val accuracy: 0.7533
Epoch: 10, ood test loss: 0.6599	ood test accuracy: 0.6554
Epoch: 11, train loss: 17.6500	train accuracy: 0.7218
Epoch: 11, val loss: 0.5294	val accuracy: 0.7400
Epoch: 11, ood test loss: 0.6461	ood test accuracy: 0.6502
Epoch: 12, train loss: 17.5110	train accuracy: 0.7645
Epoch: 12, val loss: 0.5314	val accuracy: 0.7377
Epoch: 12, ood test loss: 0.6491	ood test accuracy: 0.6476
Epoch: 13, train loss: 17.4105	train accuracy: 0.7338
Epoch: 13, val loss: 0.5156	val accuracy: 0.7667
Epoch: 13, ood test loss: 0.6583	ood test accuracy: 0.6424
Epoch: 14, train loss: 17.2639	train accuracy: 0.7305
Epoch: 14, val loss: 0.5234	val accuracy: 0.7411
Epoch: 14, ood test loss: 0.6604	ood test accuracy: 0.6345
Epoch: 15, train loss: 17.1068	train accuracy: 0.7775
Epoch: 15, val loss: 0.5126	val accuracy: 0.7511
Epoch: 15, ood test loss: 0.6698	ood test accuracy: 0.5790
Epoch: 16, train loss: 17.0111	train accuracy: 0.7706
Epoch: 16, val loss: 0.5079	val accuracy: 0.7533
Epoch: 16, ood test loss: 0.6721	ood test accuracy: 0.6345
Epoch: 17, train loss: 16.8910	train accuracy: 0.7600
Epoch: 17, val loss: 0.5162	val accuracy: 0.7556
Epoch: 17, ood test loss: 0.6752	ood test accuracy: 0.6345
Epoch: 18, train loss: 16.7757	train accuracy: 0.7626
Epoch: 18, val loss: 0.5023	val accuracy: 0.7589
Epoch: 18, ood test loss: 0.6699	ood test accuracy: 0.6424
Epoch: 19, train loss: 16.6573	train accuracy: 0.7644
Epoch: 19, val loss: 0.4882	val accuracy: 0.7522
Epoch: 19, ood test loss: 0.6509	ood test accuracy: 0.6450
Epoch: 20, train loss: 16.5521	train accuracy: 0.7592
Epoch: 20, val loss: 0.5041	val accuracy: 0.7645
Epoch: 20, ood test loss: 0.6567	ood test accuracy: 0.6528
Epoch: 21, train loss: 16.4353	train accuracy: 0.7860
Epoch: 21, val loss: 0.5096	val accuracy: 0.7333
Epoch: 21, ood test loss: 0.6708	ood test accuracy: 0.6345
Epoch: 22, train loss: 16.3297	train accuracy: 0.7922
Epoch: 22, val loss: 0.4964	val accuracy: 0.7589
Epoch: 22, ood test loss: 0.6809	ood test accuracy: 0.6372
Epoch: 23, train loss: 16.2200	train accuracy: 0.7651
Epoch: 23, val loss: 0.4960	val accuracy: 0.7489
Epoch: 23, ood test loss: 0.6739	ood test accuracy: 0.6554
Epoch: 24, train loss: 16.1554	train accuracy: 0.7617
Epoch: 24, val loss: 0.4924	val accuracy: 0.7567
Epoch: 24, ood test loss: 0.6455	ood test accuracy: 0.6554
Epoch: 25, train loss: 16.0155	train accuracy: 0.7729
Epoch: 25, val loss: 0.4852	val accuracy: 0.7522
Epoch: 25, ood test loss: 0.6264	ood test accuracy: 0.6580
Epoch: 26, train loss: 15.8923	train accuracy: 0.7686
Epoch: 26, val loss: 0.4887	val accuracy: 0.7355
Epoch: 26, ood test loss: 0.6517	ood test accuracy: 0.6476
Epoch: 27, train loss: 15.8384	train accuracy: 0.7748
Epoch: 27, val loss: 0.5001	val accuracy: 0.7500
Epoch: 27, ood test loss: 0.6719	ood test accuracy: 0.6293
Epoch: 28, train loss: 15.7314	train accuracy: 0.7643
Epoch: 28, val loss: 0.5076	val accuracy: 0.7400
Epoch: 28, ood test loss: 0.6694	ood test accuracy: 0.6267
Epoch: 29, train loss: 15.6352	train accuracy: 0.7758
Epoch: 29, val loss: 0.4823	val accuracy: 0.7444
Epoch: 29, ood test loss: 0.6559	ood test accuracy: 0.6502
Epoch: 30, train loss: 15.5431	train accuracy: 0.7574
Epoch: 30, val loss: 0.4807	val accuracy: 0.7489
Epoch: 30, ood test loss: 0.6452	ood test accuracy: 0.6554
Epoch: 31, train loss: 15.4224	train accuracy: 0.7810
Epoch: 31, val loss: 0.4854	val accuracy: 0.7578
Epoch: 31, ood test loss: 0.6431	ood test accuracy: 0.6502
Epoch: 32, train loss: 15.3429	train accuracy: 0.7729
Epoch: 32, val loss: 0.4796	val accuracy: 0.7533
Epoch: 32, ood test loss: 0.6542	ood test accuracy: 0.6476
Epoch: 33, train loss: 15.2803	train accuracy: 0.7913
Epoch: 33, val loss: 0.4868	val accuracy: 0.7567
Epoch: 33, ood test loss: 0.6658	ood test accuracy: 0.6424
Epoch: 34, train loss: 15.1729	train accuracy: 0.7574
Epoch: 34, val loss: 0.4915	val accuracy: 0.7489
Epoch: 34, ood test loss: 0.6717	ood test accuracy: 0.5738
Epoch: 35, train loss: 15.1075	train accuracy: 0.7773
Epoch: 35, val loss: 0.4788	val accuracy: 0.7578
Epoch: 35, ood test loss: 0.6683	ood test accuracy: 0.6424
Epoch: 36, train loss: 15.0045	train accuracy: 0.7730
Epoch: 36, val loss: 0.4850	val accuracy: 0.7533
Epoch: 36, ood test loss: 0.6641	ood test accuracy: 0.6606
Epoch: 37, train loss: 14.9350	train accuracy: 0.7799
Epoch: 37, val loss: 0.4862	val accuracy: 0.7455
Epoch: 37, ood test loss: 0.6646	ood test accuracy: 0.6736
Epoch: 38, train loss: 14.8439	train accuracy: 0.7652
Epoch: 38, val loss: 0.4829	val accuracy: 0.7600
Epoch: 38, ood test loss: 0.6441	ood test accuracy: 0.6528
Epoch: 39, train loss: 14.7713	train accuracy: 0.7738
Epoch: 39, val loss: 0.4877	val accuracy: 0.7478
Epoch: 39, ood test loss: 0.6468	ood test accuracy: 0.6476
Epoch: 40, train loss: 14.6663	train accuracy: 0.7958
Epoch: 40, val loss: 0.4840	val accuracy: 0.7667
Epoch: 40, ood test loss: 0.6569	ood test accuracy: 0.6528
Epoch: 41, train loss: 14.5931	train accuracy: 0.7966
Epoch: 41, val loss: 0.4821	val accuracy: 0.7500
Epoch: 41, ood test loss: 0.6513	ood test accuracy: 0.6606
Epoch: 42, train loss: 14.4981	train accuracy: 0.7818
Epoch: 42, val loss: 0.4898	val accuracy: 0.7388
Epoch: 42, ood test loss: 0.6535	ood test accuracy: 0.6736
Epoch: 43, train loss: 14.4388	train accuracy: 0.7854
Epoch: 43, val loss: 0.4913	val accuracy: 0.7645
Epoch: 43, ood test loss: 0.6636	ood test accuracy: 0.6632
Epoch: 44, train loss: 14.3794	train accuracy: 0.7728
Epoch: 44, val loss: 0.4877	val accuracy: 0.7746
Epoch: 44, ood test loss: 0.6409	ood test accuracy: 0.6606
Epoch: 45, train loss: 14.2791	train accuracy: 0.7965
Epoch: 45, val loss: 0.4850	val accuracy: 0.7768
Epoch: 45, ood test loss: 0.6532	ood test accuracy: 0.6788
Epoch: 46, train loss: 14.2067	train accuracy: 0.7800
Epoch: 46, val loss: 0.4826	val accuracy: 0.7667
Epoch: 46, ood test loss: 0.6636	ood test accuracy: 0.6554
Epoch: 47, train loss: 14.1451	train accuracy: 0.8025
Epoch: 47, val loss: 0.4795	val accuracy: 0.7656
Epoch: 47, ood test loss: 0.6467	ood test accuracy: 0.6580
Epoch: 48, train loss: 14.0513	train accuracy: 0.8019
Epoch: 48, val loss: 0.4754	val accuracy: 0.7835
Epoch: 48, ood test loss: 0.6566	ood test accuracy: 0.6632
Epoch: 49, train loss: 13.9764	train accuracy: 0.7982
Epoch: 49, val loss: 0.4781	val accuracy: 0.7567
Epoch: 49, ood test loss: 0.6554	ood test accuracy: 0.6580
100%|██████████| 50/50 [00:20<00:00,  2.44it/s]

### Training of the maximum i.i.d. disagreement rate distribution

We now have a Bayesian base model fitted to the training data. Next, we estimate the disagreement rate, with respect to the base classifier, of models that:

- agree with our base classifier on the training data
- disagree with our base classifier on unseen i.i.d. data
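Both conditions are phrased in terms of the disagreement rate between two classifiers: the fraction of inputs on which their predicted labels differ. A minimal sketch on hard labels follows. The function name is ours, and D3M may compute disagreement on temperature-scaled probabilities rather than hard labels, which this sketch does not model.

```python
def disagreement_rate(preds_a, preds_b):
    """Fraction of points on which two classifiers predict different labels."""
    if len(preds_a) != len(preds_b) or not preds_a:
        raise ValueError("need two non-empty prediction lists of equal length")
    n_disagree = sum(a != b for a, b in zip(preds_a, preds_b))
    return n_disagree / len(preds_a)


# Toy example: hard labels from the base classifier and one alternative model.
base_on_train = [0, 1, 1, 0]
other_on_train = [0, 1, 1, 0]        # agrees with the base classifier on training data
base_on_unseen = [0, 1, 1, 0, 1]
other_on_unseen = [0, 1, 0, 0, 0]    # disagrees on unseen i.i.d. data

print(disagreement_rate(base_on_train, other_on_train))    # 0.0
print(disagreement_rate(base_on_unseen, other_on_unseen))  # 0.4
```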

To do so, we repeatedly draw a batch of samples from our posterior belief over the decision surface and keep the sample with the strongest disagreement rate. We collect these maximum disagreement rates into ``model.Phi``.
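The selection procedure can be sketched in plain Python. Everything model-specific here is hypothetical: ``toy_posterior_sample`` fakes a posterior draw by randomly flipping the base classifier's binary labels, it does not enforce agreement on the training data, and the temperature used by D3M is ignored. The parameter names mirror ``Phi_size``, ``n_post_samples`` and ``data_sample_size`` in the call below, but how ``pretrain_disagreement_distribution`` uses them internally is an assumption.

```python
import random


def disagreement_rate(preds_a, preds_b):
    """Fraction of points on which two label lists differ."""
    return sum(a != b for a, b in zip(preds_a, preds_b)) / len(preds_a)


def toy_posterior_sample(base_preds, flip_prob, rng):
    """HYPOTHETICAL stand-in for one posterior draw: flip each binary base
    prediction independently with probability flip_prob."""
    return [1 - p if rng.random() < flip_prob else p for p in base_preds]


def build_phi(base_preds, phi_size, n_post_samples, data_sample_size,
              flip_prob=0.1, seed=0):
    """Collect phi_size maximum disagreement rates on random i.i.d. subsamples."""
    rng = random.Random(seed)
    phi = []
    for _ in range(phi_size):
        # Draw an i.i.d. subsample of the held-out data (here: indices into base_preds).
        idx = rng.sample(range(len(base_preds)), data_sample_size)
        sub_base = [base_preds[i] for i in idx]
        # Draw a batch of posterior samples and keep the strongest disagreement.
        rates = [
            disagreement_rate(sub_base, toy_posterior_sample(sub_base, flip_prob, rng))
            for _ in range(n_post_samples)
        ]
        phi.append(max(rates))
    return phi


base_preds = [i % 2 for i in range(200)]  # toy binary predictions on held-out data
phi = build_phi(base_preds, phi_size=200, n_post_samples=10, data_sample_size=50)
print(min(phi), max(phi), sum(phi) / len(phi))
```

Because each entry of Phi is a maximum over ``n_post_samples`` draws, Phi sits above the disagreement of a single draw, and drawing more posterior samples per entry pushes it higher still.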


In [9]:
monitor.pretrain_disagreement_distribution(dataset=dataset_dict['d3m_train'],
                                           n_post_samples=args.d3m.n_post_samples,
                                           data_sample_size=args.d3m.data_sample_size,
                                           Phi_size=args.d3m.Phi_size, 
                                           temperature=args.d3m.temp,
                                           )

100%|██████████| 1000/1000 [00:03<00:00, 274.17it/s]


### Compute FPRs and TPRs

At this point, the monitor has both of its essential components:

- a trained base classifier on i.i.d. data
- a trained distribution of i.i.d. disagreement rates

The base classifier can now be deployed on new data, provided we monitor that data periodically with either ``monitor.d3m_test`` or ``monitor.repeat_tests`` (which repeats the test and is useful for computing statistics such as the rates below).
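This tutorial does not show the internals of ``monitor.d3m_test`` or ``monitor.repeat_tests``. A common way to turn Phi into a test, assumed here, is a one-sided threshold: flag the deployment data when its maximum disagreement rate exceeds the (1 − α) quantile of Phi, with α = 0.05 chosen for illustration. Repeating the test on several deployment samples and averaging the flags gives a flag rate comparable to the ``rate`` printed below. The helper names are hypothetical, and the actual D3M decision rule may differ.

```python
import math


def quantile(values, q):
    """Empirical q-quantile of `values` using the nearest-rank rule."""
    ordered = sorted(values)
    rank = math.ceil(q * len(ordered))  # 1-based rank
    return ordered[min(max(rank, 1), len(ordered)) - 1]


def flags_shift(max_dis_rate, phi, alpha=0.05):
    """One test (assumed rule): flag if the deployment sample's maximum
    disagreement rate exceeds the (1 - alpha) quantile of Phi."""
    return max_dis_rate > quantile(phi, 1 - alpha)


def flag_rate(observed_max_dis_rates, phi, alpha=0.05):
    """Repeated tests: fraction of deployment samples that raise a flag."""
    flags = [flags_shift(r, phi, alpha) for r in observed_max_dis_rates]
    return sum(flags) / len(flags)


phi = [i / 100 for i in range(100)]               # toy Phi: 0.00, 0.01, ..., 0.99
print(flags_shift(0.50, phi))                     # False: typical i.i.d. disagreement
print(flags_shift(0.99, phi))                     # True: unusually high disagreement
print(flag_rate([0.10, 0.50, 0.97, 0.99], phi))   # 0.5
```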

In [10]:
stats = {}
dis_rates = {}

for k, dataset in {
    'd3m_train': dataset_dict['d3m_train'],
    'd3m_id': dataset_dict['d3m_id'],
    'd3m_ood': dataset_dict['d3m_ood'],
}.items():
    # rate: fraction of the n_repeats tests that flagged this split
    rate, max_dis_rates = monitor.repeat_tests(n_repeats=args.d3m.n_repeats,
                                               dataset=dataset,
                                               n_post_samples=args.d3m.n_post_samples,
                                               data_sample_size=args.d3m.data_sample_size,
                                               temperature=args.d3m.temp)
    print(f"{k}: {rate}")
    stats[k] = rate
    # summarize the maximum disagreement rates observed across the repeats
    dis_rates[k] = (np.mean(max_dis_rates), np.std(max_dis_rates))


100%|██████████| 100/100 [00:00<00:00, 272.35it/s]
d3m_train: 0.02
100%|██████████| 100/100 [00:00<00:00, 275.26it/s]
d3m_id: 0.03
100%|██████████| 100/100 [00:00<00:00, 274.74it/s]
d3m_ood: 1.0

The flag rate on the OOD split, i.e. the TPR, is 1.0: the Bayesian D3M monitor identifies the deteriorating change in the data distribution in every repeated test.

On the held-out in-distribution split (``d3m_id``), by contrast, the monitor rarely flags the sample as out-of-distribution (FPR 0.03, versus 0.02 on ``d3m_train``). Since the base classifier performs similarly on this held-out set and on the validation set, not flagging it is the correct outcome, which indicates that Bayesian D3M incurs a low false positive rate on in-distribution samples.
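Both rates are read off the same ``stats`` dictionary: on splits without deterioration every flag is a false alarm, so their flag rate is a false positive rate, while on the deteriorating split every flag is a correct detection, so its flag rate is a true positive rate. A small sketch using the values printed above; ``n_repeats = 100`` is an assumption based on the 100/100 progress bars, not a value read from the configuration.

```python
# Flag rates printed by the evaluation cell above.
stats = {"d3m_train": 0.02, "d3m_id": 0.03, "d3m_ood": 1.0}

in_distribution = ["d3m_train", "d3m_id"]  # no deterioration: flags are false alarms
deteriorating = ["d3m_ood"]                # deterioration: flags are detections

fpr = {k: stats[k] for k in in_distribution}
tpr = {k: stats[k] for k in deteriorating}

n_repeats = 100  # ASSUMPTION: inferred from the 100/100 progress bars
false_alarms = {k: round(rate * n_repeats) for k, rate in fpr.items()}

print(fpr)           # {'d3m_train': 0.02, 'd3m_id': 0.03}
print(tpr)           # {'d3m_ood': 1.0}
print(false_alarms)  # {'d3m_train': 2, 'd3m_id': 3}
```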