##### Copyright 2020 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<a href="https://colab.research.google.com/github/google-research/big_transfer/blob/master/colabs/big_transfer_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# BigTransfer (BiT): A step-by-step tutorial for state-of-the-art vision

This colab demonstrates how to:
1. Load BiT models in JAX.
2. Make predictions using BiT pre-trained on CIFAR-10.
3. Fine-tune BiT on 5-shot CIFAR-100 and get amazing results!

It is a good way to build an understanding of BiT or to quickly try things out. However, for longer training runs, we recommend using the command-line scripts at https://github.com/google-research/big_transfer.

# Install flax and run imports

In [1]:
#@markdown Select whether you would like to store data in your personal drive.
#@markdown
#@markdown If you select **yes**, you will need to authorize Colab to access
#@markdown your personal drive.
#@markdown
#@markdown If you select **no**, then any changes you make will disappear when
#@markdown this Colab's VM restarts after some time of inactivity...
use_gdrive = 'yes'  #@param ["yes", "no"]

if use_gdrive == 'yes':
  import os
  from google.colab import drive
  drive.mount('/gdrive')
  # NOTE(review): hard-coded personal Drive path -- adjust for your own Drive.
  root = '/gdrive/My Drive/Fall 20-21/COS 454/Project/cnn_txf_bias/big_transfer'
  # makedirs (not mkdir) so missing intermediate folders are created as well;
  # exist_ok keeps the cell safe to re-run when the folder already exists.
  os.makedirs(root, exist_ok=True)
  os.chdir(root)
  print(f'\nChanged CWD to "{root}"')
else:
  from IPython import display
  display.display(display.HTML(
      '<h1 style="color:red">CHANGES NOT PERSISTED</h1>'))

Mounted at /gdrive

Changed CWD to "/gdrive/My Drive/Fall 20-21/COS 454/Project/cnn_txf_bias/big_transfer"


In [2]:
!pip install flax
!pip install tfa-nightly
!pip install tensorflow_io

Collecting flax
[?25l  Downloading https://files.pythonhosted.org/packages/c7/c0/941b4d2a2164c677fe665b6ddb5ac90306d76f8ffc298f44c41c64b30f1a/flax-0.3.0-py3-none-any.whl (154kB)
[K     |██▏                             | 10kB 17.3MB/s eta 0:00:01[K     |████▎                           | 20kB 12.9MB/s eta 0:00:01[K     |██████▍                         | 30kB 8.9MB/s eta 0:00:01[K     |████████▌                       | 40kB 7.8MB/s eta 0:00:01[K     |██████████▋                     | 51kB 4.6MB/s eta 0:00:01[K     |████████████▊                   | 61kB 5.2MB/s eta 0:00:01[K     |██████████████▉                 | 71kB 5.2MB/s eta 0:00:01[K     |█████████████████               | 81kB 5.6MB/s eta 0:00:01[K     |███████████████████             | 92kB 6.0MB/s eta 0:00:01[K     |█████████████████████▏          | 102kB 5.8MB/s eta 0:00:01[K     |███████████████████████▎        | 112kB 5.8MB/s eta 0:00:01[K     |█████████████████████████▍      | 122kB 5.8MB/s eta 0:00:01

In [3]:
import io
import re

from functools import partial

import numpy as np

import matplotlib.pyplot as plt
import random

import jax
import jax.numpy as jnp

import flax
import flax.nn as nn
import flax.optim as optim
import flax.jax_utils as flax_utils

# Assert that a GPU is available. Check each device's platform rather than
# substring-matching the device repr, whose format differs across JAX versions.
assert any(d.platform == 'gpu' for d in jax.devices()), (
    'No GPU found; switch the Colab runtime type to GPU.')

import tensorflow as tf
import tensorflow_datasets as tfds

# Architecture and function for transforming BiT weights to JAX format

In [4]:
def fixed_padding(x, kernel_size):
  """Zero-pad axes 1 and 2 of a 4-D tensor ahead of a VALID conv/pool.

  The total padding is `kernel_size - 1`, split so that the trailing side gets
  the extra element when the total is odd. Axes 0 and 3 are left untouched.
  Assumes an NHWC layout (batch, height, width, channels), as used by the
  ResNet below -- TODO confirm for other callers.
  """
  total = kernel_size - 1
  before = total // 2
  after = total - before
  no_pad = (0, 0, 0)              # (low, high, interior)
  spatial = (before, after, 0)
  return jax.lax.pad(x, 0.0, (no_pad, spatial, spatial, no_pad))


def standardize(x, axis, eps):
  """Shift `x` to zero mean and scale it to unit variance along `axis`.

  Mean and (biased) variance are computed over `axis` with keepdims, and `eps`
  is added to the variance before the square root for numerical stability.
  """
  centered = x - jnp.mean(x, axis=axis, keepdims=True)
  variance = jnp.mean(jnp.square(centered), axis=axis, keepdims=True)
  return centered / jnp.sqrt(variance + eps)


class GroupNorm(nn.Module):
  """Group normalization (arxiv.org/abs/1803.08494).

  Splits the channel axis into `num_groups` groups and standardizes each group
  over the spatial axes plus the channels inside the group. It then applies a
  learned per-channel scale and bias, initialized to ones and zeros.

  Assumes a 4-D NHWC input, because the normalization axes and the
  (1, 1, 1, C) parameter shape are hard-coded for that layout. The channel
  count must be divisible by `num_groups`, otherwise the reshape fails.
  """

  def apply(self, x, num_groups=32):
    original_shape = x.shape
    channels = original_shape[-1]
    grouped = x.reshape(original_shape[:-1] +
                        (num_groups, channels // num_groups))

    # In the (N, H, W, G, C // G) view, normalize over H, W and the
    # within-group channels (axes 1, 2, 4); each group (axis 3) is separate.
    normalized = standardize(grouped, axis=[1, 2, 4], eps=1e-5)
    normalized = normalized.reshape(original_shape)

    param_shape = (1, 1, 1, channels)
    scale = self.param('scale', param_shape, nn.initializers.ones)
    bias = self.param('bias', param_shape, nn.initializers.zeros)
    return normalized * scale + bias


class StdConv(nn.Conv):
  """Convolution with weight standardization (as used by BiT).

  Every time the 'kernel' parameter is fetched, it is standardized to zero
  mean and unit variance over its first three axes (kernel height, kernel
  width, input channels in Flax's conv kernel layout), i.e. per output filter.
  All other parameters (e.g. a bias) are returned unchanged.
  """

  def param(self, name, shape, initializer):
    value = super().param(name, shape, initializer)
    if name != 'kernel':
      return value
    return standardize(value, axis=[0, 1, 2], eps=1e-10)


class RootBlock(nn.Module):
  """ResNet stem: a 7x7 stride-2 StdConv followed by a 3x3 stride-2 max-pool.

  Explicit `fixed_padding` plus VALID padding reproduces the padding scheme
  that the BiT checkpoints expect. The conv has no bias.
  """

  def apply(self, x, width):
    padded = fixed_padding(x, 7)
    stem = StdConv(padded, width, (7, 7), (2, 2),
                   padding="VALID", bias=False, name="conv_root")
    return nn.max_pool(fixed_padding(stem, 3), (3, 3), strides=(2, 2),
                       padding="VALID")


class ResidualUnit(nn.Module):
  """Bottleneck ResNet block (pre-activation / ResNetV2 style).

  Structure: GN-ReLU-1x1 conv -> GN-ReLU-3x3 conv (carries `strides`) ->
  GN-ReLU-1x1 conv to `4 * nout` channels, added to a shortcut. The shortcut
  is a strided 1x1 projection of the *pre-activated* input whenever the
  channel count or the stride changes; otherwise it is the raw input.
  """

  def apply(self, x, nout, strides=(1, 1)):
    out_channels = nout * 4
    project = x.shape[-1] != out_channels or strides != (1, 1)
    conv = StdConv.partial(bias=False)

    h = nn.relu(GroupNorm(x, name="gn1"))
    shortcut = (conv(h, out_channels, (1, 1), strides, name="conv_proj")
                if project else x)
    h = conv(h, nout, (1, 1), name="conv1")

    h = nn.relu(GroupNorm(h, name="gn2"))
    h = conv(fixed_padding(h, 3), nout, (3, 3), strides, name="conv2",
             padding='VALID')

    h = nn.relu(GroupNorm(h, name="gn3"))
    h = conv(h, out_channels, (1, 1), name="conv3")

    return h + shortcut


class ResidualBlock(nn.Module):
  """One ResNet stage: a stack of `block_size` bottleneck units.

  Only the first unit ("unit01") uses `first_stride`; the remaining units use
  stride 1. Unit names are zero-padded ("unit01", "unit02", ...) so they line
  up with the unit keys that `transform_params` reads from the checkpoint.
  """

  def apply(self, x, block_size, nout, first_stride):
    # The first unit is always built, even when block_size < 1.
    for idx in range(max(block_size, 1)):
      x = ResidualUnit(x, nout,
                       strides=first_stride if idx == 0 else (1, 1),
                       name=f"unit{idx + 1:02d}")
    return x


class ResNet(nn.Module):
  """ResNetV2 backbone with a linear classification head (BiT layout).

  Args (to `apply`):
    x: input images; assumed NHWC (batch, height, width, channels).
    num_classes: width of the output layer.
    width_factor: channel multiplier applied to the base width of 64.
    num_layers: depth; must be a key of `_block_sizes` (50, 101 or 152).

  Returns:
    float32 logits of shape (batch, num_classes). The head kernel is
    zero-initialized.
  """

  def apply(self, x, num_classes=1000,
            width_factor=1, num_layers=50):
    stage_sizes = _block_sizes[num_layers]
    base_width = 64 * width_factor

    x = RootBlock.partial(width=base_width)(x, name='root_block')

    # Four stages. Channel width doubles per stage, and every stage except
    # the first downsamples with stride 2.
    for stage, units in enumerate(stage_sizes):
      x = ResidualBlock(x, units, base_width * 2 ** stage,
                        first_stride=(2, 2) if stage else (1, 1),
                        name=f"block{stage + 1}")

    # Pre-head: final GN + ReLU (needed in pre-activation nets), then global
    # average pooling over axes 1 and 2.
    x = nn.relu(GroupNorm(x, name='norm-pre-head'))
    x = jnp.mean(x, axis=(1, 2))

    logits = nn.Dense(x, num_classes, name="conv_head",
                      kernel_init=nn.initializers.zeros)
    return logits.astype(jnp.float32)


_block_sizes = {
      50: [3, 4, 6, 3],
      101: [3, 4, 23, 3],
      152: [3, 8, 36, 3],
  }


def transform_params(params, params_tf, num_classes, init_head=False):
  """Copy BiT checkpoint weights into a JAX ResNet param tree, in place.

  BiT (TF) and this JAX model use different naming conventions. This maps
  keys such as 'resnet/block1/unit01/a/standardized_conv2d/kernel' onto
  params['block1']['unit01']['conv1']['kernel'] (TF sub-blocks a/b/c map to
  conv1-3 / gn1-3).

  Args:
    params: nested param dict (as produced by `init_by_shape`); mutated.
    params_tf: mapping from checkpoint key to array.
    num_classes: output width of the freshly zeroed head (used only when
      `init_head` is False).
    init_head: copy the checkpoint's classification head instead of
      replacing it with zeros.

  Returns:
    None.
  """
  def as_gn_param(vec):
    # GroupNorm params here have shape (1, 1, 1, C); prepend three unit axes
    # (the checkpoint vector is presumably 1-D of length C).
    return vec[None, None, None]

  params['root_block']['conv_root']['kernel'] = (
      params_tf['resnet/root_block/standardized_conv2d/kernel'])

  for block in ('block1', 'block2', 'block3', 'block4'):
    unit_names = {re.findall(r'unit\d+', key)[0]
                  for key in params_tf.keys() if block in key}
    for unit in unit_names:
      dst = params[block][unit]
      for idx, letter in enumerate('abc', start=1):
        src = f'resnet/{block}/{unit}/{letter}/'
        dst[f'conv{idx}']['kernel'] = params_tf[src + 'standardized_conv2d/kernel']
        dst[f'gn{idx}']['bias'] = as_gn_param(params_tf[src + 'group_norm/beta'])
        dst[f'gn{idx}']['scale'] = as_gn_param(params_tf[src + 'group_norm/gamma'])

      # At most one projection-shortcut kernel per unit, stored under sub-block 'a'.
      proj_keys = [key for key in params_tf.keys()
                   if f'{block}/{unit}/a/proj' in key]
      assert len(proj_keys) <= 1
      if proj_keys:
        dst['conv_proj']['kernel'] = params_tf[proj_keys[0]]

  pre_head = params['norm-pre-head']
  pre_head['bias'] = as_gn_param(params_tf['resnet/group_norm/beta'])
  pre_head['scale'] = as_gn_param(params_tf['resnet/group_norm/gamma'])

  head = params['conv_head']
  if init_head:
    # The TF head is stored as a conv kernel; [0, 0] drops its two leading
    # (presumably 1x1 spatial) axes to obtain a dense (in, out) matrix.
    head['kernel'] = params_tf['resnet/head/conv2d/kernel'][0, 0]
    head['bias'] = params_tf['resnet/head/conv2d/bias']
  else:
    in_features = head['kernel'].shape[0]
    head['kernel'] = np.zeros((in_features, num_classes), dtype=np.float32)
    head['bias'] = np.zeros(num_classes, dtype=np.float32)

# Run BiT-M-ResNet50x1 already fine-tuned on CIFAR-10

## Build model and load weights

In [7]:
# Fetch the BiT-M R50x1 checkpoint that was already fine-tuned on CIFAR-10.
with tf.io.gfile.GFile('gs://bit_models/BiT-M-R50x1-CIFAR10.npz', 'rb') as f:
  npz = np.load(f)
  # np.load on a file object returns a lazy NpzFile. Materialize every array
  # while the file is still open instead of after `with` has closed it.
  params_tf = {k: jnp.array(npz[k]) for k in npz.files}

ResNet_cifar10 = ResNet.partial(num_classes=10)


def resnet_fn(params, images):
  """Forward pass of the 10-class ResNet; returns float32 logits."""
  return ResNet_cifar10.call(params, images)


resnet_init = ResNet_cifar10.init_by_shape
_, params = resnet_init(jax.random.PRNGKey(0), [([1, 224, 224, 3], jnp.float32)])

# Overwrite the fresh params in place with the checkpoint weights, including
# the pretrained 10-way head.
transform_params(params, params_tf, 10, init_head=True)

## Prepare data

In [None]:
data_builder = tfds.builder('cifar10')
data_builder.download_and_prepare()

def _pp(data):
  """Eval preprocessing: resize to 128x128 and map pixels to [-1, 1].

  The [-1, 1] range assumes 0-255 pixel values. Only the 'image' and 'label'
  entries of the example are kept.
  """
  resized = tf.image.resize(data['image'], [128, 128])
  return {'image': (resized - 127.5) / 127.5, 'label': data['label']}

# CIFAR-10 test split in batches of 100, exposed as a numpy iterator.
data = data_builder.as_dataset(split='test').map(_pp).batch(100)
data_iter = data.as_numpy_iterator()

[1mDownloading and preparing dataset cifar10/3.0.2 (download: 162.17 MiB, generated: 132.40 MiB, total: 294.58 MiB) to /root/tensorflow_datasets/cifar10/3.0.2...[0m


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Extraction completed...', max=1.0, styl…









HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /root/tensorflow_datasets/cifar10/3.0.2.incompleteGHAXNZ/cifar10-train.tfrecord


HBox(children=(FloatProgress(value=0.0, max=50000.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /root/tensorflow_datasets/cifar10/3.0.2.incompleteGHAXNZ/cifar10-test.tfrecord


HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))

[1mDataset cifar10 downloaded and prepared to /root/tensorflow_datasets/cifar10/3.0.2. Subsequent calls will reuse this data.[0m


## Run BiT

In [None]:
# Take a fresh pass over `data` instead of reusing the one-shot `data_iter`.
# That iterator is exhausted after the first run, so re-running this cell
# would otherwise report 0/0 and raise ZeroDivisionError.
correct, n = 0, 0
for batch in data.as_numpy_iterator():
  preds = resnet_fn(params, batch['image'])
  correct += np.sum(np.argmax(preds, axis=1) == batch['label'])
  n += len(preds)

print(f"CIFAR-10 accuracy of BiT-M-R50x1: {correct / n:0.3%}")

### Generate output in CIFAR-10H format

In [11]:
# Human-readable class names, indexed by integer label, per dataset.
labelnames = {
    # https://www.cs.toronto.edu/~kriz/cifar.html
    'cifar10': (
        'airplane', 'automobile', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck',
    ),
    # https://www.cs.toronto.edu/~kriz/cifar.html
    'cifar100': (
        'apple', 'aquarium_fish', 'baby', 'bear', 'beaver',
        'bed', 'bee', 'beetle', 'bicycle', 'bottle',
        'bowl', 'boy', 'bridge', 'bus', 'butterfly',
        'camel', 'can', 'castle', 'caterpillar', 'cattle',
        'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach',
        'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
        'dolphin', 'elephant', 'flatfish', 'forest', 'fox',
        'girl', 'hamster', 'house', 'kangaroo', 'computer_keyboard',
        'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard',
        'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain',
        'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid',
        'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree',
        'plain', 'plate', 'poppy', 'porcupine', 'possum',
        'rabbit', 'raccoon', 'ray', 'road', 'rocket',
        'rose', 'sea', 'seal', 'shark', 'shrew',
        'skunk', 'skyscraper', 'snail', 'snake', 'spider',
        'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
        'tank', 'telephone', 'television', 'tiger', 'tractor',
        'train', 'trout', 'tulip', 'turtle', 'wardrobe',
        'whale', 'willow_tree', 'wolf', 'woman', 'worm',
    ),
}

# Dataset whose label names the CSV-export cells below use.
dataset = 'cifar10'

In [5]:
test_batch_size = 100

data_builder = tfds.builder('cifar10')
data_builder.download_and_prepare()

def get_data(split, repeats, batch_size, images_per_class, shuffle_buffer,
             num_classes=None):
  """Build a batched, preprocessed CIFAR-10 `tf.data` pipeline.

  Args:
    split: TFDS split name. 'train' gets random-crop and random
      horizontal-flip augmentation; any other split is only resized.
    repeats: passed to `Dataset.repeat` (None repeats indefinitely).
    batch_size: examples per batch.
    images_per_class: currently UNUSED -- the 'train' split loads every
      example. TODO: implement few-shot subsampling if it is needed.
    shuffle_buffer: buffer size for `Dataset.shuffle`.
    num_classes: depth of the one-hot labels. Defaults to the number of
      classes of the builder's 'label' feature.

  Returns:
    A `tf.data.Dataset` of dicts with a 128x128 'image' (scaled to [-1, 1],
    assuming 0-255 pixels) and a one-hot 'label'.
  """
  if num_classes is None:
    num_classes = data_builder.info.features['label'].num_classes

  data = data_builder.as_dataset(split=split)

  if split == 'train':
    # Load the whole training split into memory, then re-slice it per example.
    data = data.batch(50000)
    data = next(data.as_numpy_iterator())
    data = tf.data.Dataset.zip((tf.data.Dataset.from_tensor_slices(data['image']),
                                tf.data.Dataset.from_tensor_slices(data['label'])))
    data = data.map(lambda x, y: {'image': x, 'label': y})
  else:
    data = data.map(lambda d: {'image': d['image'], 'label': d['label']})

  def _pp(data):
    im = data['image']
    if split == 'train':
      im = tf.image.resize(im, [160, 160])
      im = tf.image.random_crop(im, [128, 128, 3])
      # A *random* flip: tf.image.flip_left_right would mirror every training
      # image, which is a fixed transform rather than augmentation.
      im = tf.image.random_flip_left_right(im)
    else:
      im = tf.image.resize(im, [128, 128])
    im = (im - 127.5) / 127.5
    return {'image': im, 'label': tf.one_hot(data['label'], num_classes)}

  data = data.repeat(repeats)
  data = data.shuffle(shuffle_buffer)
  data = data.map(_pp)
  return data.batch(batch_size)

# data_train = get_data(split='train', repeats=None, images_per_class=5000,
#                       batch_size=test_batch_size, shuffle_buffer=500)
data_test = get_data(split='test', repeats=1, images_per_class=None,
                      batch_size=test_batch_size, shuffle_buffer=1)

[1mDownloading and preparing dataset cifar10/3.0.2 (download: 162.17 MiB, generated: 132.40 MiB, total: 294.58 MiB) to /root/tensorflow_datasets/cifar10/3.0.2...[0m


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Extraction completed...', max=1.0, styl…









HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /root/tensorflow_datasets/cifar10/3.0.2.incompleteF27LN5/cifar10-train.tfrecord


HBox(children=(FloatProgress(value=0.0, max=50000.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Shuffling and writing examples to /root/tensorflow_datasets/cifar10/3.0.2.incompleteF27LN5/cifar10-test.tfrecord


HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))

[1mDataset cifar10 downloaded and prepared to /root/tensorflow_datasets/cifar10/3.0.2. Subsequent calls will reuse this data.[0m


In [14]:
import pickle
import csv

# Load the baseline model. (A `resnet_init(...)` call used to run here, but its
# result was immediately overwritten by this pickle, so it has been dropped.)
# NOTE: pickle.load can execute arbitrary code -- only load trusted checkpoints.
with open('./bit_models/imagenet21k+imagenet2012+cifar10/BiT-M-R50x1_Baseline.pkl', 'rb') as file_obj:
  params = pickle.load(file_obj)

total_images = 10000
count = 0

# CIFAR-10 Test, written in CIFAR-10H CSV format.
print('Running CIFAR-10 Test')
# Open once in 'w' mode: re-running the cell then overwrites the CSV instead of
# appending a second header and a duplicate copy of every row.
with open('../cifar-10h/resnet50.csv', mode='w', newline='') as csv_file:
  csv_writer = csv.writer(csv_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
  csv_writer.writerow(['annotator_id', 'trial_idx', 'true_category', 'chosen_category', 'true_label', 'chosen_label', 'correct_guess'])

  for test_im in data_test.as_numpy_iterator():
    logits = resnet_fn(params, test_im['image'])

    preds = np.asarray(logits).argmax(axis=1)
    trues = test_im['label'].argmax(axis=-1)

    for pred, true in zip(preds, trues):
      pred_label = labelnames[dataset][pred]
      true_label = labelnames[dataset][true]
      csv_writer.writerow([0, count, true_label, pred_label, true, pred, f'{1 if pred_label == true_label else 0}'])
      count += 1

    print('\r %0.2f%%' % (count/(total_images)*100), end='')

print()

Running CIFAR-10 Test
 100.00%


In [15]:
import pickle
import csv

# Load the fine-tuned model. (A `resnet_init(...)` call used to run here, but
# its result was immediately overwritten by this pickle, so it has been dropped.)
# NOTE: pickle.load can execute arbitrary code -- only load trusted checkpoints.
# The file name (including its spelling) must match the file on disk.
with open('./bit_models/imagenet21k+imagenet2012+cifar10/BiT-M-R50x1_Baseline+Rotate+Cutout+Sobel Filtering+Gaussian Blur+Color Distortion+Gaussain Noise.pkl', 'rb') as file_obj:
  params = pickle.load(file_obj)

total_images = 10000
count = 0

# CIFAR-10 Test, written in CIFAR-10H CSV format.
print('Running CIFAR-10 Test')
# Open once in 'w' mode: re-running the cell then overwrites the CSV instead of
# appending a second header and a duplicate copy of every row.
with open('../cifar-10h/resnet50_ft.csv', mode='w', newline='') as csv_file:
  csv_writer = csv.writer(csv_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
  csv_writer.writerow(['annotator_id', 'trial_idx', 'true_category', 'chosen_category', 'true_label', 'chosen_label', 'correct_guess'])

  for test_im in data_test.as_numpy_iterator():
    logits = resnet_fn(params, test_im['image'])

    preds = np.asarray(logits).argmax(axis=1)
    trues = test_im['label'].argmax(axis=-1)

    for pred, true in zip(preds, trues):
      pred_label = labelnames[dataset][pred]
      true_label = labelnames[dataset][true]
      csv_writer.writerow([0, count, true_label, pred_label, true, pred, f'{1 if pred_label == true_label else 0}'])
      count += 1

    print('\r %0.2f%%' % (count/(total_images)*100), end='')

print()

Running CIFAR-10 Test
 100.00%
