# Hyperparameter Ensembles for Robustness and Uncertainty Quantification

## pip install을 사용하여 깃허브에 있는 패키지를 설치합니다

In [1]:
!pip install "git+https://github.com/google/uncertainty-baselines.git#egg=uncertainty_baselines"

In [2]:
!pip install "git+https://github.com/google-research/robustness_metrics.git#egg=robustness_metrics"

In [3]:
!pip install "git+https://github.com/google/edward2"

## 필요한 라이브러리를 import하고 구글 드라이브를 마운트 합니다

In [4]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
 
import uncertainty_baselines as ub

# Mount Google Drive so that files under /content/gdrive are reachable
# (CHECKPOINT_DIR further down points into this mount).
from google.colab import drive

drive.mount('/content/gdrive')

## 앙상블 예측의 정확도를 계산하는 함수, cross_entropy 계산하는 함수, greedy_selection 알고리즘, 체크포인트의 디렉토리를 구분하는 함수를 선언합니다

* greedy_selection이 최적화하는 목표(objective)는 nll, acc 또는 nll-acc가 될 수 있습니다.

In [None]:
 def _ensemble_accuracy(labels, logits_list):
  """Return the accuracy of the averaged-probability ensemble prediction.

  Args:
    labels: Class-index labels, as accepted by
      `tf.keras.metrics.SparseCategoricalAccuracy`.
    logits_list: Per-member logits (list of tensors or one stacked tensor);
      members are laid out along axis 0.

  Returns:
    Scalar tensor with the accuracy of the ensemble.
  """
  # Softmax normalizes the last axis; the mean over axis 0 mixes the members.
  ensemble_probs = tf.reduce_mean(tf.nn.softmax(logits_list), axis=0)
  metric = tf.keras.metrics.SparseCategoricalAccuracy()
  metric.update_state(labels, ensemble_probs)
  return metric.result()
 
def _ensemble_cross_entropy(labels, logits):
  """Return the mean negative log-likelihood of the uniform member mixture.

  Args:
    labels: Class indices; cast to int32 and broadcast to every member.
    logits: Per-member logits with members along axis 0 and classes along the
      last axis.

  Returns:
    Scalar tensor: the ensemble NLL averaged over all remaining axes.
  """
  member_logits = tf.convert_to_tensor(logits)
  num_members = float(member_logits.shape[0])
  int_labels = tf.cast(labels, tf.int32)
  # Repeat the labels once per member so they line up with the logits.
  tiled_labels = tf.broadcast_to(
      int_labels[tf.newaxis, ...], tf.shape(member_logits)[:-1])
  member_ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=tiled_labels, logits=member_logits)
  # -log(mean_m p_m(y)) == -logsumexp_m(-ce_m) + log(M)
  mixture_nll = (
      -tf.reduce_logsumexp(-member_ce, axis=0) + tf.math.log(num_members))
  return tf.reduce_mean(mixture_nll)
 
 
def greedy_selection(val_logits, val_labels, max_ens_size, objective='nll'):
  """Greedily grow an ensemble, with replacement, on validation predictions.

  In every round each model of the pool is tentatively appended to the current
  ensemble; the candidate giving the lowest objective is kept, provided it
  strictly improves on the best objective seen so far. The same model may be
  picked several times.

  Args:
    val_logits: Sequence of per-model validation logits, indexed by model id.
    val_labels: Validation labels shared by all models.
    max_ens_size: Selection stops once this many *distinct* models are chosen.
    objective: One of 'nll', 'acc' or 'nll-acc'.

  Returns:
    Tuple `(ens, best_acc, best_nll)`: the selected model ids (possibly with
    repeats) and the accuracy / NLL of the last accepted ensemble. When no
    model is ever accepted these are `[]`, `0.` and `np.inf`.

  Raises:
    AssertionError: If `objective` is not one of the supported names.
  """
  assert_msg = 'Unknown objective type (received {}).'.format(objective)
  assert objective in ('nll', 'acc', 'nll-acc'), assert_msg
  # Every objective below is *minimized* (see `obj < best_objective`), so
  # accuracy has to enter with a negative sign. Returning `+acc` here would
  # make the 'acc' objective pick the least accurate ensemble.
  if objective == 'nll':
    get_objective = lambda acc, nll: nll
  elif objective == 'acc':
    get_objective = lambda acc, nll: -acc
  else:
    get_objective = lambda acc, nll: nll - acc

  best_acc = 0.
  best_nll = np.inf
  best_objective = np.inf
  ens = []

  # The size limit counts distinct members: repeats do not consume budget.
  while len(set(ens)) < max_ens_size:
    current_val_logits = [val_logits[model_id] for model_id in ens]
    best_model_id = None
    for model_id, logits in enumerate(val_logits):
      candidate_logits = current_val_logits + [logits]
      acc = _ensemble_accuracy(val_labels, candidate_logits)
      nll = _ensemble_cross_entropy(val_labels, candidate_logits)
      obj = get_objective(acc, nll)
      if obj < best_objective:
        best_acc = acc
        best_nll = nll
        best_objective = obj
        best_model_id = model_id
    if best_model_id is None:
      print('Ensemble could not be improved: Greedy selection stops.')
      break
    ens.append(best_model_id)
  return ens, best_acc, best_nll
 
 
def parse_checkpoint_dir(checkpoint_dir):
  """Collect one latest-checkpoint path per child of `checkpoint_dir`.

  Each direct child is walked until a directory containing a file whose name
  includes both 'checkpoint' and '.index' is found; the latest checkpoint
  reported by `tf.train.latest_checkpoint` for that directory is recorded and
  the walk of that child stops.

  Args:
    checkpoint_dir: Directory whose direct children are scanned.

  Returns:
    List of checkpoint paths (without file suffix), one per matching child.
  """

  def _is_index_file(filename):
    return 'checkpoint' in filename and '.index' in filename

  checkpoint_paths = []
  for subdir in tf.io.gfile.glob(os.path.join(checkpoint_dir, '*')):
    for dirpath, _, filenames in tf.io.gfile.walk(subdir):
      if not any(_is_index_file(name) for name in filenames):
        continue
      latest_prefix = tf.train.latest_checkpoint(dirpath)
      checkpoint_paths.append(os.path.join(dirpath, latest_prefix))
      break  # Only the first matching directory of each child is used.
  return checkpoint_paths

## dataset의 종류와 훈련 비율, 배치 사이즈, 앙상블 사이즈를 선언합니다

In [None]:
# TFDS dataset name; used below to look up the dataset info.
DATASET = 'cifar10'
# Share of the train split kept for training; validation_percent is derived
# from it as 1 - TRAIN_PROPORTION.
TRAIN_PROPORTION = 0.95
# Batch size used to load the datasets and to slice the test logits.
BATCH_SIZE = 64
# Passed to greedy_selection as the maximum number of distinct members.
ENSEMBLE_SIZE = 4

## Cifar 데이터셋을 로드하는 함수를 선언합니다

In [None]:
from typing import Any, Dict, Optional, Union

from robustness_metrics.common import types
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
from uncertainty_baselines.datasets import augment_utils
from uncertainty_baselines.datasets import augmix
from uncertainty_baselines.datasets import base

# Three-element constants handed to augmix.normalize_convert_image and
# augmix.do_augmix as `mean` / `std`.
# NOTE(review): presumably per-RGB-channel CIFAR-10 statistics — confirm.
CIFAR10_MEAN = tf.constant([0.4914, 0.4822, 0.4465])
CIFAR10_STD = tf.constant([0.2470, 0.2435, 0.2616])

def _tuple_dict_fn_converter(fn, *args):

  def dict_fn(batch_dict):
    images, labels = fn(*args, batch_dict['features'], batch_dict['labels'])
    return {'features': images, 'labels': labels}

  return dict_fn


class _CifarDataset(base.BaseDataset):
  """Common builder for the CIFAR-style image datasets defined in this file.

  Wraps a TFDS builder and provides the per-example pre-processing (training
  augmentation, normalization, label formatting) and the optional per-batch
  mixup step through the `_create_process_*_fn` hooks.
  NOTE(review): how `base.BaseDataset` invokes these hooks, and where
  `self._seed`, `self._is_training` and `self._enumerate_id_key` are set, is
  not visible in this file — presumably in `base.BaseDataset.__init__`.
  """

  def __init__(
      self,
      name: str,
      fingerprint_key: Optional[str],
      split: str,
      seed: Optional[Union[int, tf.Tensor]] = None,
      validation_percent: float = 0.0,
      shuffle_buffer_size: Optional[int] = None,
      num_parallel_parser_calls: int = 64,
      drop_remainder: bool = True,
      normalize: bool = True,
      try_gcs: bool = False,
      download_data: bool = False,
      use_bfloat16: bool = False,
      aug_params: Optional[Dict[str, Any]] = None,
      data_dir: Optional[str] = None,
      is_training: Optional[bool] = None,
      **unused_kwargs: Dict[str, Any]):
    """Create the TFDS builder and record the pre-processing options.

    Args:
      name: TFDS dataset name handed to `tfds.builder`.
      fingerprint_key: Forwarded to `base.BaseDataset`; the subclasses in this
        file pass 'id' or None.
      split: Dataset split; combined with `validation_percent` through
        `base.get_validation_percent_split`.
      seed: Forwarded to `base.BaseDataset`.
      validation_percent: Forwarded to `base.get_validation_percent_split`.
      shuffle_buffer_size: Forwarded to `base.BaseDataset`.
      num_parallel_parser_calls: Forwarded to `base.BaseDataset`.
      drop_remainder: Forwarded to `base.BaseDataset`.
      normalize: Whether images are normalized with CIFAR10_MEAN/CIFAR10_STD
        (only consulted when augmix is not used).
      try_gcs: Accepted for signature compatibility; not forwarded to
        `tfds.builder`.
      download_data: Forwarded to `base.BaseDataset`.
      use_bfloat16: Produce bfloat16 images instead of float32.
      aug_params: Augmentation options. Keys read by this class:
        'adaptive_mixup', 'ensemble_size', 'mixup_coeff', 'augmix',
        'random_augment', 'aug_count', 'mixup_alpha', 'label_smoothing'.
        The dict is stored by reference and receives a default 'mixup_coeff'
        when adaptive mixup is enabled without one.
      data_dir: Forwarded to `tfds.builder`.
      is_training: Whether training augmentation is applied; when None it is
        True only for the train split.
      **unused_kwargs: Ignored.
    """
    self._normalize = normalize
    dataset_builder = tfds.builder(
        name,
        data_dir=data_dir)
    # Take the class count from the dataset itself rather than assuming 10:
    # this class also backs Cifar100Dataset, whose labels exceed 9.
    self._num_label_classes = (
        dataset_builder.info.features['label'].num_classes)
    if is_training is None:
      is_training = split in ['train', tfds.Split.TRAIN]
    new_split = base.get_validation_percent_split(
        dataset_builder, validation_percent, split)
    super(_CifarDataset, self).__init__(
        name=name,
        dataset_builder=dataset_builder,
        split=new_split,
        seed=seed,
        is_training=is_training,
        shuffle_buffer_size=shuffle_buffer_size,
        num_parallel_parser_calls=num_parallel_parser_calls,
        drop_remainder=drop_remainder,
        fingerprint_key=fingerprint_key,
        download_data=download_data,
        cache=True)

    self._use_bfloat16 = use_bfloat16
    if aug_params is None:
      aug_params = {}
    self._adaptive_mixup = aug_params.get('adaptive_mixup', False)
    ensemble_size = aug_params.get('ensemble_size', 1)
    if self._adaptive_mixup and 'mixup_coeff' not in aug_params:
      # One coefficient per (member, class); all ones as the starting value.
      aug_params['mixup_coeff'] = tf.ones(
          [ensemble_size, self._num_label_classes])
    self._aug_params = aug_params

  def _create_process_example_fn(self) -> base.PreProcessFn:
    """Return the function that pre-processes a single example."""

    def _example_parser(example: types.Features) -> types.Features:
      """Augment (training only), normalize and re-key one example.

      The returned dict is a copy of `example` in which 'image' / 'label' are
      replaced by 'features' / 'labels'.
      """
      image = example['image']
      image_dtype = tf.bfloat16 if self._use_bfloat16 else tf.float32
      use_augmix = self._aug_params.get('augmix', False)
      if self._is_training:
        image_shape = tf.shape(image)
        # Pad by 4 pixels in total per spatial dimension, then randomly crop
        # back to the original size.
        image = tf.image.resize_with_crop_or_pad(
            image, image_shape[0] + 4, image_shape[1] + 4)
        # Derive independent stateless seeds for each random op of this
        # example.
        per_example_step_seed = tf.random.experimental.stateless_fold_in(
            self._seed, example[self._enumerate_id_key])
        per_example_step_seeds = tf.random.experimental.stateless_split(
            per_example_step_seed, num=4)
        # Crop to (height, width, 3). The width must come from
        # image_shape[1]; using the height twice only works for square images.
        image = tf.image.stateless_random_crop(
            image,
            (image_shape[0], image_shape[1], 3),
            seed=per_example_step_seeds[0])
        image = tf.image.stateless_random_flip_left_right(
            image,
            seed=per_example_step_seeds[1])

        if self._aug_params.get('random_augment', False):
          count = self._aug_params['aug_count']
          augment_seeds = tf.random.experimental.stateless_split(
              per_example_step_seeds[2], num=count)
          augmenter = augment_utils.RandAugment()
          augmented = [
              augmenter.distort(image, seed=augment_seeds[c])
              for c in range(count)
          ]
          # Stacks the `count` augmented copies along a new leading axis.
          image = tf.stack(augmented)

        if use_augmix:
          augmenter = augment_utils.RandAugment()
          image = augmix.do_augmix(
              image, self._aug_params, augmenter, image_dtype,
              mean=CIFAR10_MEAN, std=CIFAR10_STD,
              seed=per_example_step_seeds[3])

      # NOTE(review): when augmix is requested no normalization/dtype
      # conversion happens here, also outside training — presumably
      # do_augmix handles it in the training branch; confirm for evaluation.
      if not use_augmix:
        if self._normalize:
          image = augmix.normalize_convert_image(
              image, image_dtype, mean=CIFAR10_MEAN, std=CIFAR10_STD)
        else:
          image = tf.image.convert_image_dtype(image, image_dtype)
      parsed_example = example.copy()
      parsed_example['features'] = image

      # Mixup and label smoothing need dense labels; otherwise the class
      # index is kept (as float32).
      mixup_alpha = self._aug_params.get('mixup_alpha', 0)
      label_smoothing = self._aug_params.get('label_smoothing', 0.)
      should_onehot = mixup_alpha > 0 or label_smoothing > 0
      if should_onehot:
        parsed_example['labels'] = tf.one_hot(
            example['label'], self._num_label_classes, dtype=tf.float32)
      else:
        parsed_example['labels'] = tf.cast(example['label'], tf.float32)

      del parsed_example['image']
      del parsed_example['label']
      return parsed_example

    return _example_parser

  def _create_process_batch_fn(
      self,
      batch_size: int) -> Optional[base.PreProcessFn]:
    """Return the batch-level mixup function, or None when mixup is off.

    Mixup is only applied during training and when 'mixup_alpha' > 0.
    """
    if self._is_training and self._aug_params.get('mixup_alpha', 0) > 0:
      if self._adaptive_mixup:
        return _tuple_dict_fn_converter(
            augmix.adaptive_mixup, batch_size, self._aug_params)
      else:
        return _tuple_dict_fn_converter(
            augmix.mixup, batch_size, self._aug_params)
    return None


class Cifar10Dataset(_CifarDataset):
  """Dataset builder for CIFAR-10 (TFDS name 'cifar10')."""

  def __init__(self, **kwargs):
    """Forward `kwargs` to `_CifarDataset` with the CIFAR-10 name and key."""
    super().__init__(name='cifar10', fingerprint_key='id', **kwargs)


class Cifar100Dataset(_CifarDataset):
  """Dataset builder for CIFAR-100 (TFDS name 'cifar100')."""

  def __init__(self, **kwargs):
    """Forward `kwargs` to `_CifarDataset` with the CIFAR-100 name and key."""
    super().__init__(name='cifar100', fingerprint_key='id', **kwargs)


class Cifar10CorruptedDataset(_CifarDataset):
  """Dataset builder for CIFAR10-C (corrupted CIFAR-10)."""

  def __init__(
      self,
      corruption_type: str,
      severity: int,
      **kwargs):
    """Build the CIFAR10-C dataset for one corruption type and severity.

    Args:
      corruption_type: Name of the corruption.
      severity: Severity of the corruption, an integer between 1 and 5.
      **kwargs: Extra keyword arguments forwarded to `_CifarDataset`.
    """
    # The TFDS config name encodes both the corruption and its severity.
    tfds_name = f'cifar10_corrupted/{corruption_type}_{severity}'
    super().__init__(name=tfds_name, fingerprint_key=None, **kwargs)

## datasets의 utilities를 사용할 수 있는 함수를 로드합니다

In [None]:
import json
import logging
from typing import Any, List, Tuple, Union
import warnings

import tensorflow as tf
import tensorflow_datasets as tfds
from uncertainty_baselines.datasets.base import BaseDataset
from uncertainty_baselines.datasets.cifar import Cifar100Dataset
from uncertainty_baselines.datasets.cifar import Cifar10CorruptedDataset
#from uncertainty_baselines.datasets.cifar import Cifar10Dataset
from uncertainty_baselines.datasets.cifar100_corrupted import Cifar100CorruptedDataset
from uncertainty_baselines.datasets.clinc_intent import ClincIntentDetectionDataset
from uncertainty_baselines.datasets.criteo import CriteoDataset
from uncertainty_baselines.datasets.diabetic_retinopathy_detection import DiabeticRetinopathyDetectionDataset
from uncertainty_baselines.datasets.genomics_ood import GenomicsOodDataset
from uncertainty_baselines.datasets.glue import GlueDatasets
from uncertainty_baselines.datasets.imagenet import ImageNetDataset
from uncertainty_baselines.datasets.mnist import MnistDataset
from uncertainty_baselines.datasets.mnli import MnliDataset
from uncertainty_baselines.datasets.movielens import MovieLensDataset
from uncertainty_baselines.datasets.places import Places365Dataset
from uncertainty_baselines.datasets.random import RandomGaussianImageDataset
from uncertainty_baselines.datasets.random import RandomRademacherImageDataset
from uncertainty_baselines.datasets.svhn import SvhnDataset
from uncertainty_baselines.datasets.toxic_comments import CivilCommentsDataset
from uncertainty_baselines.datasets.toxic_comments import CivilCommentsIdentitiesDataset
from uncertainty_baselines.datasets.toxic_comments import WikipediaToxicityDataset

# Optional dataset: when its import fails, warn and keep a None placeholder so
# the name below is still defined.
try:
  from uncertainty_baselines.datasets.speech_commands import SpeechCommandsDataset  # pylint: disable=g-import-not-at-top
except ImportError as e:
  warnings.warn(f'Skipped due to ImportError: {e}')
  SpeechCommandsDataset = None

# Registry mapping dataset names to their builder classes; `get` looks names
# up here.
# NOTE: 'cifar10' resolves to the Cifar10Dataset class defined in this
# notebook (its library import above is commented out), whereas the other
# CIFAR classes are re-imported from uncertainty_baselines.
DATASETS = {
    'cifar100': Cifar100Dataset,
    'cifar10': Cifar10Dataset,
    'cifar10_corrupted': Cifar10CorruptedDataset,
    'cifar100_corrupted': Cifar100CorruptedDataset,
    'civil_comments': CivilCommentsDataset,
    'civil_comments_identities': CivilCommentsIdentitiesDataset,
    'clinic_intent': ClincIntentDetectionDataset,
    'criteo': CriteoDataset,
    'diabetic_retinopathy_detection': DiabeticRetinopathyDetectionDataset,
    'imagenet': ImageNetDataset,
    'mnist': MnistDataset,
    'mnli': MnliDataset,
    'movielens': MovieLensDataset,
    'places365': Places365Dataset,
    'random_gaussian': RandomGaussianImageDataset,
    'random_rademacher': RandomRademacherImageDataset,
    # May be None when the optional speech_commands import failed.
    'speech_commands': SpeechCommandsDataset,
    'svhn_cropped': SvhnDataset,
    'glue/cola': GlueDatasets['glue/cola'],
    'glue/sst2': GlueDatasets['glue/sst2'],
    'glue/mrpc': GlueDatasets['glue/mrpc'],
    'glue/qqp': GlueDatasets['glue/qqp'],
    'glue/qnli': GlueDatasets['glue/qnli'],
    'glue/rte': GlueDatasets['glue/rte'],
    'glue/wnli': GlueDatasets['glue/wnli'],
    'glue/stsb': GlueDatasets['glue/stsb'],
    'wikipedia_toxicity': WikipediaToxicityDataset,
    'genomics_ood': GenomicsOodDataset,
}


def get_dataset_names() -> List[str]:
  """Return the names of all datasets registered in `DATASETS`."""
  return list(DATASETS)


def get(
    dataset_name: str,
    split: Union[Tuple[str, float], str, tfds.Split],
    **hyperparameters: Any) -> BaseDataset:
  """Build the dataset registered under `dataset_name`.

  Args:
    dataset_name: Key of `DATASETS`.
    split: Dataset split, passed to the dataset class as `split`.
    **hyperparameters: Extra keyword arguments passed to the dataset class.

  Returns:
    An instance of the registered dataset class.

  Raises:
    ValueError: If `dataset_name` is not registered, or if it is registered
      but its class is unavailable (placeholder None left by a failed
      optional import).
  """
  # Top-level tensors are converted to plain Python values for logging only;
  # the dataset class still receives the original `hyperparameters`.
  hyperparameters_py = {
      k: (v.numpy().tolist() if isinstance(v, tf.Tensor) else v)
      for k, v in hyperparameters.items()
  }
  logging.info(
      'Building dataset %s with additional kwargs:\n%s',
      dataset_name,
      # `default=repr` keeps this log line from raising on nested values that
      # json cannot encode (e.g. a tensor inside an `aug_params` dict).
      json.dumps(hyperparameters_py, indent=2, sort_keys=True, default=repr))
  if dataset_name not in DATASETS:
    raise ValueError('Unrecognized dataset name: {!r}'.format(dataset_name))

  dataset_class = DATASETS[dataset_name]
  if dataset_class is None:
    # Fail with a clear message instead of "'NoneType' object is not
    # callable".
    raise ValueError(
        'Dataset {!r} is registered but unavailable: its optional import '
        'failed.'.format(dataset_name))
  return dataset_class(
      split=split,
      **hyperparameters)

## 텐서플로우 데이터셋 빌더 함수를 통해 cifar10 데이터셋을 다운받고 준비시킵니다

In [None]:
# Download CIFAR-10 through TFDS and prepare it on local disk; the dataset
# classes above default to download_data=False.
dataset_builder = tfds.builder('cifar10')
dataset_builder.download_and_prepare()

Downloading and preparing dataset cifar10/3.0.2 (download: 162.17 MiB, generated: 132.40 MiB, total: 294.58 MiB) to /root/tensorflow_datasets/cifar10/3.0.2...


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Extraction completed...', max=1.0, styl…









HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /root/tensorflow_datasets/cifar10/3.0.2.incompleteC9X90H/cifar10-train.tfrecord


HBox(children=(FloatProgress(value=0.0, max=50000.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /root/tensorflow_datasets/cifar10/3.0.2.incompleteC9X90H/cifar10-test.tfrecord


HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))

Dataset cifar10 downloaded and prepared to /root/tensorflow_datasets/cifar10/3.0.2. Subsequent calls will reuse this data.


## cifar10 데이터 셋을 test 셋과 validation 셋으로 나누어 로드합니다

In [None]:
# load data
# NOTE(review): the dataset info is looked up through DATASET, but the loaders
# below are hard-wired to Cifar10Dataset — keep the two in sync.
ds_info = tfds.builder(DATASET).info
num_classes = ds_info.features['label'].num_classes
# test set
# Number of full batches in the test split (floor division).
steps_per_eval = ds_info.splits['test'].num_examples // BATCH_SIZE 
test_dataset=Cifar10Dataset(split=tfds.Split.TEST).load(batch_size=BATCH_SIZE)
# validation set
# The validation share is whatever is not used for training.
validation_percent = 1 - TRAIN_PROPORTION
val_dataset = Cifar10Dataset(split=tfds.Split.VALIDATION,
    validation_percent=validation_percent,
    drop_remainder=False).load(batch_size=BATCH_SIZE)
# NOTE(review): floor division counts full batches only, so a final partial
# validation batch is never consumed even though drop_remainder=False.
steps_per_val_eval = int(ds_info.splits['train'].num_examples *
                          validation_percent) // BATCH_SIZE  # validation set

## 모델 wide_resnet을 생성하는 함수를 선언합니다

In [None]:
import functools
from typing import Any, Dict, Iterable, Optional

import tensorflow as tf

# Names of the per-layer l2 coefficients that wide_resnet reads from `hps`.
HP_KEYS = ('bn_l2', 'input_conv_l2', 'group_1_conv_l2', 'group_2_conv_l2',
           'group_3_conv_l2', 'dense_kernel_l2', 'dense_bias_l2')

# Keras BatchNormalization with this file's fixed epsilon and momentum.
BatchNormalization = functools.partial(  # pylint: disable=invalid-name
    tf.keras.layers.BatchNormalization,
    epsilon=1e-5,  
    momentum=0.9)


def Conv2D(filters, seed=None, **kwargs):  # pylint: disable=invalid-name
  """Create a Keras Conv2D layer with this file's default settings.

  Defaults: 3x3 kernel, 'same' padding, no bias, He-normal initialization
  seeded with `seed`. Any entry in `kwargs` overrides the matching default.

  Args:
    filters: Number of output filters.
    seed: Seed for the He-normal kernel initializer.
    **kwargs: Extra / overriding arguments for `tf.keras.layers.Conv2D`.

  Returns:
    The configured `tf.keras.layers.Conv2D` layer.
  """
  layer_kwargs = {
      'kernel_size': 3,
      'padding': 'same',
      'use_bias': False,
      'kernel_initializer': tf.keras.initializers.HeNormal(seed=seed),
      **kwargs,
  }
  return tf.keras.layers.Conv2D(filters, **layer_kwargs)


def basic_block(
    inputs: tf.Tensor,
    filters: int,
    strides: int,
    conv_l2: float,
    bn_l2: float,
    seed: int,
    version: int) -> tf.Tensor:
  """Apply one two-convolution residual block to `inputs`.

  With `version == 2` batch norm + ReLU precede the first convolution; with
  `version == 1` batch norm follows the second convolution and a ReLU follows
  the residual addition.

  Args:
    inputs: Input tensor of the block.
    filters: Number of filters of each convolution.
    strides: Stride of the first convolution (and of the shortcut projection).
    conv_l2: l2 coefficient for the convolution kernels.
    bn_l2: l2 coefficient for the batch-norm beta and gamma.
    seed: Seed from which the three convolution seeds are derived.
    version: Block variant, 1 or 2.

  Returns:
    The output tensor of the block.
  """
  l2 = tf.keras.regularizers.l2

  def batch_norm(tensor):
    return BatchNormalization(beta_regularizer=l2(bn_l2),
                              gamma_regularizer=l2(bn_l2))(tensor)

  def relu(tensor):
    return tf.keras.layers.Activation('relu')(tensor)

  def conv(tensor, conv_seed, **conv_kwargs):
    return Conv2D(filters,
                  seed=conv_seed,
                  kernel_regularizer=l2(conv_l2),
                  **conv_kwargs)(tensor)

  shortcut = inputs
  residual = inputs
  if version == 2:
    residual = relu(batch_norm(residual))
  conv_seeds = tf.random.experimental.stateless_split(
      [seed, seed + 1], 3)[:, 0]
  residual = conv(residual, conv_seeds[0], strides=strides)
  residual = relu(batch_norm(residual))
  residual = conv(residual, conv_seeds[1], strides=1)
  if version == 1:
    residual = batch_norm(residual)
  # Project the shortcut with a 1x1 convolution when the shapes differ.
  if not shortcut.shape.is_compatible_with(residual.shape):
    shortcut = conv(shortcut, conv_seeds[2], kernel_size=1, strides=strides)
  merged = tf.keras.layers.add([shortcut, residual])
  if version == 1:
    merged = relu(merged)
  return merged


def group(inputs, filters, strides, num_blocks, conv_l2, bn_l2, version, seed):
  """Chain `num_blocks` residual blocks with a shared configuration.

  Args:
    inputs: Input tensor of the group.
    filters: Number of filters of every block.
    strides: Stride of the first block; later blocks use stride 1.
    num_blocks: Number of blocks in the group.
    conv_l2: l2 coefficient for the convolution kernels.
    bn_l2: l2 coefficient for the batch-norm parameters.
    version: Block variant passed to `basic_block`.
    seed: Seed from which one seed per block is derived.

  Returns:
    The output tensor of the last block.
  """
  block_seeds = tf.random.experimental.stateless_split(
      [seed, seed + 1], num_blocks)[:, 0]
  # Only the first block may change the resolution.
  block_strides = [strides] + [1] * (num_blocks - 1)
  x = inputs
  for block_index, block_stride in enumerate(block_strides):
    x = basic_block(
        x,
        filters=filters,
        strides=block_stride,
        conv_l2=conv_l2,
        bn_l2=bn_l2,
        version=version,
        seed=block_seeds[block_index])
  return x


def _parse_hyperparameters(l2: float, hps: Dict[str, float]):

  assert_msg = ('Ambiguous hyperparameter specifications: either l2 or hps '
                'must be provided (received {} and {}).'.format(l2, hps))
  is_specified = lambda h: bool(h) and all(v is not None for v in h.values())
  only_l2_is_specified = l2 is not None and not is_specified(hps)
  only_hps_is_specified = l2 is None and is_specified(hps)
  assert only_l2_is_specified or only_hps_is_specified, assert_msg
  if only_hps_is_specified:
    assert_msg = 'hps must contain the keys {}!={}.'.format(HP_KEYS, hps.keys())
    assert set(hps.keys()).issuperset(HP_KEYS), assert_msg
    return hps
  else:
    return {k: l2 for k in HP_KEYS}


def wide_resnet(
    input_shape: Iterable[int],
    depth: int,
    width_multiplier: int,
    num_classes: int,
    l2: float,
    version: int = 2,
    seed: int = 42,
    hps: Optional[Dict[str, float]] = None) -> tf.keras.models.Model:
  """Build a Wide ResNet Keras model.

  Args:
    input_shape: Shape of one input example.
    depth: Total depth; must be of the form 6n+4.
    width_multiplier: Factor applied to the 16/32/64 base filter counts.
    num_classes: Number of output units of the final dense layer.
    l2: Single l2 coefficient for all layers (alternative to `hps`).
    version: Residual block variant, 1 or 2.
    seed: Seed from which all initializer seeds are derived.
    hps: Per-layer l2 coefficients keyed by `HP_KEYS` (alternative to `l2`).

  Returns:
    A `tf.keras.Model` named 'wide_resnet-{depth}-{width_multiplier}' that
    outputs logits.

  Raises:
    ValueError: If `depth` is not of the form 6n+4.
  """
  l2_reg = tf.keras.regularizers.l2
  hps = _parse_hyperparameters(l2, hps)

  seeds = tf.random.experimental.stateless_split([seed, seed + 1], 5)[:, 0]
  if (depth - 4) % 6 != 0:
    raise ValueError('depth should be 6n+4 (e.g., 16, 22, 28, 40).')
  num_blocks = (depth - 4) // 6

  def bn_relu(tensor):
    normed = BatchNormalization(
        beta_regularizer=l2_reg(hps['bn_l2']),
        gamma_regularizer=l2_reg(hps['bn_l2']))(tensor)
    return tf.keras.layers.Activation('relu')(normed)

  inputs = tf.keras.layers.Input(shape=input_shape)
  x = Conv2D(16,
             strides=1,
             seed=seeds[0],
             kernel_regularizer=l2_reg(hps['input_conv_l2']))(inputs)
  if version == 1:
    x = bn_relu(x)

  # (base filters, stride, l2 key) of the three groups; group i uses seeds[i].
  group_specs = (
      (16, 1, 'group_1_conv_l2'),
      (32, 2, 'group_2_conv_l2'),
      (64, 2, 'group_3_conv_l2'),
  )
  for group_index, (base_filters, group_strides, l2_key) in enumerate(
      group_specs, start=1):
    x = group(x,
              filters=base_filters * width_multiplier,
              strides=group_strides,
              num_blocks=num_blocks,
              conv_l2=hps[l2_key],
              bn_l2=hps['bn_l2'],
              version=version,
              seed=seeds[group_index])

  if version == 2:
    x = bn_relu(x)
  x = tf.keras.layers.AveragePooling2D(pool_size=8)(x)
  x = tf.keras.layers.Flatten()(x)
  logits = tf.keras.layers.Dense(
      num_classes,
      kernel_initializer=tf.keras.initializers.HeNormal(seed=seeds[4]),
      kernel_regularizer=l2_reg(hps['dense_kernel_l2']),
      bias_regularizer=l2_reg(hps['dense_bias_l2']))(x)
  model_name = 'wide_resnet-{}-{}'.format(depth, width_multiplier)
  return tf.keras.Model(inputs=inputs, outputs=logits, name=model_name)


def create_model(
    batch_size: Optional[int],
    depth: int,
    width_multiplier: int,
    input_shape: Iterable[int] = (32, 32, 3),
    num_classes: int = 10,
    l2_weight: float = 0.0,
    version: int = 2,
    **unused_kwargs: Dict[str, Any]) -> tf.keras.models.Model:
  """Build a Wide ResNet through a generic model-factory signature.

  Args:
    batch_size: Unused.
    depth: Total depth of the network (6n+4).
    width_multiplier: Factor applied to the base filter counts.
    input_shape: Shape of one input example.
    num_classes: Number of output classes.
    l2_weight: l2 coefficient applied to all layers.
    version: Residual block variant, 1 or 2.
    **unused_kwargs: Ignored.

  Returns:
    The model returned by `wide_resnet`.
  """
  del batch_size  # unused arg
  resnet_kwargs = dict(
      input_shape=input_shape,
      depth=depth,
      width_multiplier=width_multiplier,
      num_classes=num_classes,
      l2=l2_weight,
      version=version)
  return wide_resnet(**resnet_kwargs)

## wide_resnet 모델을 생성합니다

In [None]:
def _extract_hyperparameter_dictionary():
  """Create the dictionary of hyperparameters from FLAGS."""
  hp_keys = ('bn_l2', 'input_conv_l2', 'group_1_conv_l2', 'group_2_conv_l2',
           'group_3_conv_l2', 'dense_kernel_l2', 'dense_bias_l2')
  hps = {'bn_l2':None, 'input_conv_l2':None, 'group_1_conv_l2':None, 'group_2_conv_l2':None,
           'group_3_conv_l2':None, 'dense_kernel_l2':None, 'dense_bias_l2':None}
  return hps  

# Build the Wide ResNet 28-10. Every value of the hps dict is None, so
# _parse_hyperparameters falls back to the single l2 coefficient given here.
model = wide_resnet(
    input_shape=(32, 32, 3),
    depth=28,
    width_multiplier=10,
    num_classes=num_classes,
    l2=2e-4,
    hps=_extract_hyperparameter_dictionary())

## saveCheckpoits.py 파일에서 저장된 checkpoints 파일을 로드합니다

In [5]:
 # Load checkpoints
import os
# Directory on the mounted Google Drive holding one sub-directory per model.
CHECKPOINT_DIR = '/content/gdrive/My Drive/tmp/'
ensemble_filenames = parse_checkpoint_dir(CHECKPOINT_DIR)

model_pool_size = len(ensemble_filenames)
# A single Checkpoint object bound to `model`; each restore() below overwrites
# the weights of that same model instance.
checkpoint = tf.train.Checkpoint(model=model)
print('Model pool size: {}'.format(model_pool_size))

## validation set 에서의 앙상블 logits을 계산합니다

In [8]:
# For every model of the pool: restore its weights into `model` and collect
# its logits on the validation set.
val_logits, val_labels = [], []
for m, ensemble_filename in enumerate(ensemble_filenames):
  tf.keras.backend.clear_session()
  checkpoint.restore(ensemble_filename)
  val_iterator = iter(val_dataset)
  val_logits_m = []
  for _ in range(steps_per_val_eval):
    inputs = next(val_iterator)
    features = inputs['features']
    labels = inputs['labels']
    val_logits_m.append(model(features, training=False))
    # Labels are identical for all models, so gather them on the first pass
    # only.
    if m == 0:
      val_labels.append(labels)

  val_logits.append(tf.concat(val_logits_m, axis=0))
  if m == 0:
    # From here on val_labels is a single tensor instead of a list of batches.
    val_labels = tf.concat(val_labels, axis=0)


  # Report progress for every 10th model and for the last one.
  if m % 10 == 0 or m == model_pool_size - 1:
    percent = (m + 1.) / model_pool_size
    message = ('{:.1%} completion for prediction on validation set: '
                'model {:d}/{:d}.'.format(percent, m + 1, model_pool_size))
    print(message)

## greedy member selection을 통해 validation set에서의 앙상블을 구축합니다

In [7]:
# Greedy member selection on the validation predictions, minimizing the NLL.
selected_members, val_acc, val_nll = greedy_selection(val_logits, val_labels,
                                                        ENSEMBLE_SIZE,
                                                        objective='nll')
# The greedy procedure can pick a model more than once; keep each id once.
# NOTE(review): the order of a set is not the selection order.
unique_selected_members = list(set(selected_members))
message = ('Members selected by greedy procedure: model ids = {} (with {} unique '
            'member(s))').format(
                selected_members, len(unique_selected_members))
print(message)

## 평가지표 NLL, accuracy를 metrics로 선언합니다

In [9]:
 # Evaluate the following metrics on the test step
# Ensemble metrics: the NLL is computed by _ensemble_cross_entropy and only
# averaged here; the accuracy is fed the averaged member probabilities.
metrics = {
    'ensemble/negative_log_likelihood': tf.keras.metrics.Mean(),
    'ensemble/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
}
# Single-model metrics for comparison.
metrics_single = {
    # from_logits=True is required: the evaluation loop feeds this metric raw
    # logits, while the Keras default (from_logits=False) expects
    # probabilities and would report a wrong NLL.
    'single/negative_log_likelihood':
        tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True),
    'single/accuracy': tf.keras.metrics.SparseCategoricalAccuracy(),
}

## test set 에서의 앙상블 logits을 계산합니다


In [None]:
# For every distinct selected member: restore its weights into `model` and
# collect its logits over the test batches.
logits_test = []
for m, member_id in enumerate(unique_selected_members):
  ensemble_filename = ensemble_filenames[member_id]
  checkpoint.restore(ensemble_filename)
  logits = []
  test_iterator = iter(test_dataset)
  for _ in range(steps_per_eval):
    features = next(test_iterator)['features']
    logits.append(model(features, training=False))
  logits_test.append(tf.concat(logits, axis=0))
# Stack the per-member logits along a new leading (member) axis.
logits_test = tf.convert_to_tensor(logits_test)
print('Completed computation of member logits on the test set.')

## test set에서의 평가를 실시합니다

* 비교를 위해 단일 모델과 앙상블 모델 두개의 결과를 출력합니다

In [10]:
 # compute test metrics
# Walk the test set once more for the labels and slice the matching rows out
# of the stored logits.
# NOTE(review): this relies on test_dataset yielding its batches in the same
# order as in the logit-collection loop — confirm.
test_iterator = iter(test_dataset)
for step in range(steps_per_eval):
  labels = next(test_iterator)['labels']
  # All members (axis 0), examples of the current batch (axis 1).
  logits = logits_test[:, (step*BATCH_SIZE):((step+1)*BATCH_SIZE)]
  labels = tf.cast(labels, tf.int32)
  negative_log_likelihood = _ensemble_cross_entropy(labels, logits)
  per_probs = tf.nn.softmax(logits)
  probs = tf.reduce_mean(per_probs, axis=0)
  metrics['ensemble/negative_log_likelihood'].update_state(
      negative_log_likelihood)
  metrics['ensemble/accuracy'].update_state(labels, probs)

  # The "single model" baseline is the first entry of logits_test.
  logits_single = logits_test[0, (step*BATCH_SIZE):((step+1)*BATCH_SIZE)]
  probs_single = tf.nn.softmax(logits_single)
  # NOTE(review): raw logits are passed here, so the
  # SparseCategoricalCrossentropy metric must be built with from_logits=True
  # (the Keras default expects probabilities).
  metrics_single['single/negative_log_likelihood'].update_state(labels, logits_single)
  metrics_single['single/accuracy'].update_state(labels, probs_single)

  percent = (step + 1) / steps_per_eval
  # Report progress every 25 steps and on the last step.
  if step % 25 == 0 or step == steps_per_eval - 1:
    message = ('{:.1%} completion final test prediction'.format(percent))
    print(message)

ensemble_results = {name: metric.result() for name, metric in metrics.items()}
single_results = {name: metric.result() for name, metric in metrics_single.items()}

# 앙상블 모델의 결과와 단일 모델의 결과를 출력합니다

In [11]:
 # Print the ensemble metrics, followed by the single-model metrics.
 print('Ensemble performance:')
for m, val in ensemble_results.items():
  print('   {}: {}'.format(m, val))

print('\nFor comparison:')
for m, val in single_results.items():
  print('   {}: {}'.format(m, val))