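# BiT (Big Transfer) — fine-tuning example

Fine-tunes a pre-trained BiT-M-R50x1 checkpoint on CIFAR-100 with TensorFlow 2 / Keras. The BiT hyper-rule chooses the input resolution, the MixUp strength and the learning-rate schedule.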

### Imports

In [1]:
from functools import partial
import time
import os

import numpy as np
import input_pipeline_tf2_or_jax as input_pipeline
import bit_tf2.models as models
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()

import bit_common
import bit_hyperrule


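### Alternative: train from the command line

The same fine-tuning run can be launched as a script, using the PyTorch, JAX or TF2 implementation: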


```sh
python3 -m bit_{pytorch|jax|tf2}.train --name cifar10_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset cifar10
```



### Set configs

In [21]:
# Output directory for logs, and directory holding the pre-trained BiT checkpoints (.h5).
logdir = 'tmp/'
bit_pretrained_dir = '.'

# Dataset and run name (the run name is used by the logger).
dataset = 'cifar100'
name = 'test'

# BiT model variant. The checkpoint path is derived from it, so changing
# `model` alone picks up the matching weights file.
model = 'BiT-M-R50x1'
bit_model_file = os.path.join(bit_pretrained_dir, f'{model}.h5')

# Batch sizes. NOTE(review): `batch_eval` and `batch_split` appear unused in this
# notebook; the test pipeline is batched with `batch` as well.
batch_eval = 32
batch = 128
batch_split = 1

tfds_manual_dir = None
examples_per_class = None  # None = use every training example
examples_per_class_seed = 0
base_lr = 0.001
eval_every = None  # None = validate once, at the end of the schedule


### Set up logging and the distribution strategy

In [22]:
class AttrDict(dict):
    """A dict whose entries are also readable and writable as attributes.

    After construction, `obj.key` and `obj['key']` refer to the same value,
    because the instance namespace *is* the mapping.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Alias attribute storage to the dict itself so both access styles share data.
        self.__dict__ = self
        
# `bit_common.setup_logger` is handed the log directory and run name through an
# attribute-style dict.
d = AttrDict(logdir=logdir, name=name)

logger = bit_common.setup_logger(d)
logger.info(f'Available devices: {tf.config.list_physical_devices()}')

# Replicate the model across every visible device (one CPU replica in the run below).
strategy = tf.distribute.MirroredStrategy()
num_devices = strategy.num_replicas_in_sync
print(f'Number of devices: {num_devices}')

2020-05-24 15:54:20,837 [INFO] bit_common: {'logdir': 'tmp/', 'name': 'test'}
2020-05-24 15:54:20,838 [INFO] bit_common: Available devices: [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:XLA_CPU:0', device_type='XLA_CPU')]






INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


2020-05-24 15:54:20,840 [INFO] tensorflow: Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


Number of devices: 1


### Get dataset

In [23]:
# Statistics of the training split (num_examples, num_classes, ...), displayed below.
dataset_info = input_pipeline.get_dataset_info(dataset, 'train', examples_per_class)
dataset_info

2020-05-24 15:54:22,022 [INFO] absl: Load pre-computed DatasetInfo (eg: splits, num examples,...) from GCS: cifar100/3.0.2
2020-05-24 15:54:24,844 [INFO] absl: Load dataset info from /var/folders/gd/3nb3fp150sv5clsry0gx9v6c0000gn/T/tmpo6k1ovratfds
2020-05-24 15:54:24,847 [INFO] absl: Field info.citation from disk and from code do not match. Keeping the one from code.


{'original_num_examples': 50000, 'num_examples': 50000, 'num_classes': 100}

In [24]:
# BiT hyper-rule: per-dataset resize and crop resolutions for the input pipeline.
(resize_size, crop_size) = bit_hyperrule.get_resolution_from_dataset(dataset)


In [25]:
# Training pipeline. The MixUp strength comes from the BiT hyper-rule, based on the
# training-set size. repeats=None presumably means "repeat indefinitely"; fit()
# bounds the run via steps_per_epoch.
train_mixup_alpha = bit_hyperrule.get_mixup(dataset_info['num_examples'])
data_train = input_pipeline.get_data(
    dataset=dataset,
    mode='train',
    repeats=None,
    batch_size=batch,
    resize_size=resize_size,
    crop_size=crop_size,
    examples_per_class=examples_per_class,
    examples_per_class_seed=examples_per_class_seed,
    mixup_alpha=train_mixup_alpha,
    num_devices=num_devices,
    tfds_manual_dir=tfds_manual_dir,
)

2020-05-24 15:55:16,281 [INFO] absl: Load pre-computed DatasetInfo (eg: splits, num examples,...) from GCS: cifar100/3.0.2
2020-05-24 15:55:16,449 [INFO] absl: Load dataset info from /var/folders/gd/3nb3fp150sv5clsry0gx9v6c0000gn/T/tmp7l1ry_z4tfds
2020-05-24 15:55:16,452 [INFO] absl: Field info.citation from disk and from code do not match. Keeping the one from code.
2020-05-24 15:55:16,454 [INFO] absl: Load pre-computed DatasetInfo (eg: splits, num examples,...) from GCS: cifar100/3.0.2
2020-05-24 15:55:16,620 [INFO] absl: Load dataset info from /var/folders/gd/3nb3fp150sv5clsry0gx9v6c0000gn/T/tmp0n7pk16etfds
2020-05-24 15:55:16,623 [INFO] absl: Field info.citation from disk and from code do not match. Keeping the one from code.
2020-05-24 15:55:16,624 [INFO] absl: Generating dataset cifar100 (/Users/ramine.tinati/tensorflow_datasets/cifar100/3.0.2)


[1mDownloading and preparing dataset cifar100/3.0.2 (download: 160.71 MiB, generated: 132.03 MiB, total: 292.74 MiB) to /Users/ramine.tinati/tensorflow_datasets/cifar100/3.0.2...[0m


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Extraction completed...', max=1.0, styl…

2020-05-24 15:55:17,211 [INFO] absl: Downloading https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz into /Users/ramine.tinati/tensorflow_datasets/downloads/cs.toronto.edu_kriz_cifar-100-binaryzK0jb7CkNxmV4pH2clu5WdAlIotsPlZhrMxx9-DELEk.tar.gz.tmp.fc40ccc548c840ada59e5441e409064f...
2020-05-24 15:56:07,839 [INFO] absl: Generating split train










HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /Users/ramine.tinati/tensorflow_datasets/cifar100/3.0.2.incompleteAGYOCL/cifar100-train.tfrecord


HBox(children=(FloatProgress(value=0.0, max=50000.0), HTML(value='')))

2020-05-24 15:56:35,765 [INFO] absl: Done writing /Users/ramine.tinati/tensorflow_datasets/cifar100/3.0.2.incompleteAGYOCL/cifar100-train.tfrecord. Shard lengths: [50000]
2020-05-24 15:56:35,774 [INFO] absl: Generating split test


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /Users/ramine.tinati/tensorflow_datasets/cifar100/3.0.2.incompleteAGYOCL/cifar100-test.tfrecord


HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))

2020-05-24 15:56:41,568 [INFO] absl: Done writing /Users/ramine.tinati/tensorflow_datasets/cifar100/3.0.2.incompleteAGYOCL/cifar100-test.tfrecord. Shard lengths: [10000]
2020-05-24 15:56:41,572 [INFO] absl: Skipping computing stats for mode ComputeStatsMode.AUTO.
2020-05-24 15:56:41,575 [INFO] absl: Constructing tf.data.Dataset for split train[:98%], from /Users/ramine.tinati/tensorflow_datasets/cifar100/3.0.2


[1mDataset cifar100 downloaded and prepared to /Users/ramine.tinati/tensorflow_datasets/cifar100/3.0.2. Subsequent calls will reuse this data.[0m


In [26]:
# Evaluation pipeline: one pass (repeats=1) over the *whole* test split, with no MixUp.
# examples_per_class must not be 1 here. That keeps only one image per class
# (100 for CIFAR-100), fewer than a single batch of `batch` images, and the
# reshape cell below requires every element to hold exactly `batch` examples.
# The batch size stays `batch` (not `batch_eval`) because that reshape uses `batch`.
# NOTE(review): assumes get_data drops the final incomplete batch — confirm.
data_test = input_pipeline.get_data(
    dataset=dataset, mode='test',
    repeats=1, batch_size=batch,
    resize_size=resize_size, crop_size=crop_size,
    examples_per_class=None, examples_per_class_seed=0,
    mixup_alpha=None,
    num_devices=num_devices,
    tfds_manual_dir=tfds_manual_dir)

2020-05-24 15:56:41,689 [INFO] absl: Load dataset info from /Users/ramine.tinati/tensorflow_datasets/cifar100/3.0.2
2020-05-24 15:56:41,692 [INFO] absl: Load dataset info from /Users/ramine.tinati/tensorflow_datasets/cifar100/3.0.2
2020-05-24 15:56:41,695 [INFO] absl: Reusing dataset cifar100 (/Users/ramine.tinati/tensorflow_datasets/cifar100/3.0.2)
2020-05-24 15:56:41,696 [INFO] absl: Constructing tf.data.Dataset for split test, from /Users/ramine.tinati/tensorflow_datasets/cifar100/3.0.2


In [27]:
def reshape_for_keras(features, batch_size, crop_size):
    """Turn a pipeline element into the `(images, labels)` pair Keras `fit` expects.

    Args:
      features: dict holding "image" and "label" tensors.
      batch_size: examples per element; the reshape requires exactly this many.
      crop_size: side length of the square, 3-channel images.

    Returns:
      `(image, label)` reshaped to `(batch_size, crop_size, crop_size, 3)` and
      `(batch_size, -1)`. The reshaped tensors are also written back into `features`.
    """
    image_shape = (batch_size, crop_size, crop_size, 3)
    label_shape = (batch_size, -1)
    features["image"] = tf.reshape(features["image"], image_shape)
    features["label"] = tf.reshape(features["label"], label_shape)
    return features["image"], features["label"]


In [28]:
# Bind the fixed batch/crop sizes once, then convert both pipelines to (images, labels).
to_keras_batch = partial(reshape_for_keras, batch_size=batch, crop_size=crop_size)
data_train = data_train.map(to_keras_batch)
data_test = data_test.map(to_keras_batch)

### Set Up Model

In [29]:
# Build, load and compile the BiT ResNet-v2 inside the strategy scope, so the
# model's and optimizer's variables are created under the MirroredStrategy.
# Order matters: build -> load pre-trained weights -> replace head -> compile.
with strategy.scope():
    # Width multiplier: trailing digit of the model name (the "1" in
    # "BiT-M-R50x1") times 4.
    filters_factor = int(model[-1])*4
    tf_model = models.ResnetV2(
        num_units=models.NUM_UNITS[model],
        # Presumably the upstream checkpoint's class count (ImageNet-21k), so the
        # layer shapes match the weights loaded below — TODO confirm.
        num_outputs=21843,
        filters_factor=filters_factor,
        name="resnet",
        trainable=True,
        dtype=tf.float32)

    # Create the variables (any spatial size, 3 channels) so load_weights can fill them.
    tf_model.build((None, None, None, 3))
    logger.info(f'Loading weights...')
    tf_model.load_weights(bit_model_file)
    logger.info(f'Weights loaded into model!')

    # Swap the pre-trained classifier for a fresh, zero-initialised dense head
    # sized for this dataset's classes.
    # NOTE(review): this sets a private attribute of ResnetV2 and assumes its
    # call() uses `self._head` — confirm against bit_tf2.models.
    tf_model._head = tf.keras.layers.Dense(
        units=dataset_info['num_classes'],
        use_bias=True,
        kernel_initializer="zeros",
        trainable=True,
        name="head/dense")

    # Step boundaries of the BiT schedule for this training-set size; the last
    # entry is used as the total number of training steps.
    lr_supports = bit_hyperrule.get_schedule(dataset_info['num_examples'])

    schedule_length = lr_supports[-1]
    # NOTE: the schedule is deliberately NOT rescaled by batch size (the line
    # below would do that). Only change this if verified necessary, and do the
    # same across all three codebases (PyTorch, JAX, TF2).
    # schedule_length = schedule_length * 512 / args.batch

    # No learning rate given here: the BiTLRSched callback sets it on every step.
    optimizer = tf.keras.optimizers.SGD(momentum=0.9)
    # The head emits raw logits; the labels are passed as per-class vectors.
    loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

    tf_model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy'])


2020-05-24 15:56:43,913 [INFO] bit_common: Loading weights...
2020-05-24 15:56:44,239 [INFO] bit_common: Weights loaded into model!


### Train Model

In [30]:
# tf_model.summary()

In [31]:
class BiTLRSched(tf.keras.callbacks.Callback):
    """Keras callback applying the BiT hyper-rule learning-rate schedule.

    Before every training batch it asks `bit_hyperrule.get_lr` for the rate at the
    current global step and writes it into the optimizer.

    Args:
      base_lr: peak learning rate passed to `bit_hyperrule.get_lr`.
      num_samples: training-set size, passed to `get_lr` to select the schedule.
    """

    def __init__(self, base_lr, num_samples):
        # Initialise the Keras base class so attributes Keras relies on (e.g.
        # `model`) exist before `set_model` is called.
        super().__init__()
        # Global step counter. Unlike the per-epoch `batch` index Keras passes
        # to on_train_batch_begin, it is never reset between epochs.
        self.step = 0
        self.base_lr = base_lr
        self.num_samples = num_samples

    def on_train_batch_begin(self, batch, logs=None):
        # NOTE(review): get_lr may be undefined (None) past the end of the
        # schedule; the fit() call is sized to stop before that — confirm.
        lr = bit_hyperrule.get_lr(self.step, self.num_samples, self.base_lr)
        tf.keras.backend.set_value(self.model.optimizer.lr, lr)
        self.step += 1


In [None]:
logger.info('Fine-tuning the model...')
tf.io.gfile.makedirs(logdir)
tf.io.gfile.makedirs(bit_pretrained_dir)

# Keras "epochs" are only evaluation intervals here: validation runs after every
# `steps_per_epoch` training steps. Clamp to the schedule length so that an
# `eval_every` larger than the schedule cannot yield zero epochs (no training).
steps_per_epoch = min(eval_every or schedule_length, schedule_length)
# Floor division: if `eval_every` does not divide the schedule, the final
# `schedule_length % steps_per_epoch` steps are skipped rather than running the
# LR callback beyond the end of the BiT schedule.
num_epochs = schedule_length // steps_per_epoch

history = tf_model.fit(
    data_train,
    steps_per_epoch=steps_per_epoch,
    epochs=num_epochs,
    validation_data=data_test,  # used only to measure performance, never for tuning
    callbacks=[BiTLRSched(base_lr, dataset_info['num_examples'])],
)

2020-05-24 15:56:44,296 [INFO] bit_common: Fine-tuning the model...




### Evaluate

In [None]:
# Validation runs at the end of each Keras epoch, i.e. after
# (epoch + 1) * steps_per_epoch training steps. `steps_per_epoch` is set in the
# training cell. This notebook has no `args` namespace, and `eval_every` may be None.
for epoch, accu in enumerate(history.history['val_accuracy']):
    logger.info(
        f'Step: {(epoch + 1) * steps_per_epoch}, '
        f'Test accuracy: {accu:0.3f}')