See code at https://github.com/google-research/vision_transformer/

See paper at https://arxiv.org/abs/2010.11929

This Colab allows you to run the [JAX](https://jax.readthedocs.org) implementation of the Vision Transformer.

##### Copyright 2020 Google LLC.

In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<a href="https://colab.research.google.com/github/google-research/vision_transformer/blob/master/vit_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Setup

Needs to be executed once in every VM.

The cells below set up the working directory (optionally on your Google Drive, where the repository code is expected to live) and install the necessary dependencies.

In [2]:
#@markdown Select whether you would like to store data in your personal drive.
#@markdown
#@markdown If you select **yes**, you will need to authorize Colab to access
#@markdown your personal drive.
#@markdown
#@markdown If you select **no**, then any changes you make will disappear when
#@markdown this Colab's VM restarts after some time of inactivity...
use_gdrive = 'yes'  #@param ["yes", "no"]

if use_gdrive == 'yes':
  from google.colab import drive
  drive.mount('/gdrive')
  # Working directory (repository checkout) inside the mounted Drive.
  root = '/gdrive/My Drive/Fall 20-21/COS 454/Project/cnn_txf_bias/vision_transformer'
  import os
  # makedirs (unlike mkdir) also creates missing parent folders and, with
  # exist_ok=True, is a no-op when the folder exists, so the cell is re-runnable.
  os.makedirs(root, exist_ok=True)
  os.chdir(root)
  print(f'\nChanged CWD to "{root}"')
else:
  # Warn loudly that nothing written in this session survives a VM restart.
  from IPython import display
  display.display(display.HTML(
      '<h1 style="color:red">CHANGES NOT PERSISTED</h1>'))

Mounted at /gdrive

Changed CWD to "/gdrive/My Drive/Fall 20-21/COS 454/Project/cnn_txf_bias/vision_transformer"


In [3]:
!pip install -qr ./vit_jax/requirements_new.txt
!pip install tfa-nightly
!pip install tensorflow_io
!pip install tfds-nightly

[K     |████████████████████████████████| 61kB 2.8MB/s 
[K     |████████████████████████████████| 153kB 5.4MB/s 
[K     |████████████████████████████████| 92kB 4.5MB/s 
[K     |████████████████████████████████| 144.5MB 88kB/s 
[K     |████████████████████████████████| 4.3MB 49.5MB/s 
[?25hCollecting tfa-nightly
[?25l  Downloading https://files.pythonhosted.org/packages/fd/bb/4cd4b92207c52e00382a06370f06b3ec4cdb8f2cefc6c53090a6716a978b/tfa_nightly-0.13.0.dev20210119020053-cp36-cp36m-manylinux2010_x86_64.whl (706kB)
[K     |████████████████████████████████| 716kB 4.4MB/s 
Installing collected packages: tfa-nightly
Successfully installed tfa-nightly-0.13.0.dev20210119020053
Collecting tensorflow_io
[?25l  Downloading https://files.pythonhosted.org/packages/07/3c/b45c30448cd6a04f25b088da024229149323fa44bc6322a7372bb556eada/tensorflow_io-0.17.0-cp36-cp36m-manylinux2010_x86_64.whl (25.3MB)
[K     |████████████████████████████████| 25.3MB 1.4MB/s 
Installing collected packages: tens

### Imports

In [4]:
# Shows all available pre-trained models.
# Lists every checkpoint in the public ViT GCS bucket with its size, upload
# timestamp and path (-l long listing, -h human-readable sizes).
!gsutil ls -lh gs://vit_models/*

      65 B  2020-10-21T07:59:00Z  gs://vit_models/README.txt

gs://vit_models/imagenet21k+imagenet2012/:
377.57 MiB  2020-11-30T16:17:02Z  gs://vit_models/imagenet21k+imagenet2012/R50+ViT-B_16.npz
330.29 MiB  2020-10-29T17:05:52Z  gs://vit_models/imagenet21k+imagenet2012/ViT-B_16-224.npz
 331.4 MiB  2020-10-20T11:48:22Z  gs://vit_models/imagenet21k+imagenet2012/ViT-B_16.npz
336.89 MiB  2020-10-20T11:47:36Z  gs://vit_models/imagenet21k+imagenet2012/ViT-B_32.npz
  1.13 GiB  2020-10-29T17:08:31Z  gs://vit_models/imagenet21k+imagenet2012/ViT-L_16-224.npz
  1.14 GiB  2020-10-20T11:53:44Z  gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz
  1.14 GiB  2020-10-20T11:50:56Z  gs://vit_models/imagenet21k+imagenet2012/ViT-L_32.npz

gs://vit_models/imagenet21k/:
439.85 MiB  2020-11-30T10:10:15Z  gs://vit_models/imagenet21k/R50+ViT-B_16.npz
393.69 MiB  2020-10-22T21:38:39Z  gs://vit_models/imagenet21k/ViT-B_16.npz
400.01 MiB  2020-11-02T08:30:56Z  gs://vit_models/imagenet21k/ViT-B_32.npz
  2.46 

In [5]:
# Name of the ViT variant under evaluation; later cells use it to build the
# model, locate its checkpoint and name the exported result files.
model = "ViT-B_32"

In [6]:
#@markdown TPU setup : Boilerplate for connecting JAX to TPU.

# Configure JAX to use the Colab TPU when one is attached; otherwise only a
# message is printed and JAX keeps its default backend.
# NOTE(review): presumably this must run before any JAX computation so the
# backend flags take effect — confirm.
import os
if 'google.colab' in str(get_ipython()) and 'COLAB_TPU_ADDR' in os.environ:
  # Make sure the Colab Runtime is set to Accelerator: TPU.
  import requests
  # TPU_DRIVER_MODE is defined right after the request, so re-running this
  # cell in the same kernel does not send the driver-version request again.
  if 'TPU_DRIVER_MODE' not in globals():
    # Ask the TPU host (port 8475) for driver version tpu_driver0.1-dev20191206.
    url = 'http://' + os.environ['COLAB_TPU_ADDR'].split(':')[0] + ':8475/requestversion/tpu_driver0.1-dev20191206'
    # NOTE(review): `resp` is never inspected, so a failed request goes unnoticed.
    resp = requests.post(url)
    TPU_DRIVER_MODE = 1

  # The following is required to use TPU Driver as JAX's backend.
  from jax.config import config
  config.FLAGS.jax_xla_backend = "tpu_driver"
  # Connect over gRPC to the TPU address exposed by Colab.
  config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']
  print('Registered TPU:', config.FLAGS.jax_backend_target)
else:
  print('No TPU detected. Can be changed under "Runtime/Change runtime type".')

No TPU detected. Can be changed under "Runtime/Change runtime type".


In [7]:
import flax
import jax
from matplotlib import pyplot as plt
import numpy as np
import tqdm
import tensorflow as tf
import tensorflow_addons as tfa
import flax.jax_utils as flax_utils

# Shows the number of available devices.
# In a CPU/GPU runtime this will be a single device.
# In a TPU runtime this will be 8 cores.
# The returned list is the cell's displayed output (e.g. [CpuDevice(id=0)]).
jax.local_devices()



[CpuDevice(id=0)]

In [8]:
# Import files from repository.
# Updating the files in the editor on the right will immediately update the
# modules by re-importing them.

import sys
# Make the current working directory importable so `vit_jax` resolves as a
# package; the membership check avoids duplicate entries on re-runs.
if './' not in sys.path:
  sys.path.append('./')

# Automatically reload edited modules before each cell is executed.
%load_ext autoreload
%autoreload 2

from vit_jax import checkpoint
from vit_jax import hyper
from vit_jax import input_pipeline
from vit_jax import logging
from vit_jax import models
from vit_jax import momentum_clip
from vit_jax import train

# NOTE: `logging` here is vit_jax.logging and shadows the stdlib module name.
# Presumably it writes log output under ./logs — confirm in vit_jax/logging.py.
logger = logging.setup_logger('./logs')

In [9]:
VisionTransformer = models.KNOWN_MODELS[model].partial(num_classes=1000)

# Which weights to evaluate:
#   True  -> checkpoint fine-tuned on the augmented ImageNet-2012 data,
#   False -> public imagenet21k+imagenet2012 checkpoint.
# Only the selected checkpoint is loaded. Previously both were loaded and the
# first was immediately discarded (each ViT-B checkpoint is ~330 MiB).
use_finetuned = True

# Augmentation tag of the fine-tuned checkpoint. It must match the file name on
# disk verbatim, including the "Gaussain" spelling.
finetune_tag = ('Baseline+Rotate+Cutout+Sobel Filtering+Gaussian Blur'
                '+Color Distortion+Gaussain Noise')

if use_finetuned:
  # Built from `model` (was hard-coded to ViT-B_32), so changing the model cell
  # cannot silently load weights for a different architecture.
  ckpt_path = (f'./vit_models/imagenet21k+imagenet2012+imagenet2012/'
               f'{model}_{finetune_tag}.npz')
else:
  ckpt_path = f'./vit_models/imagenet21k+imagenet2012/{model}.npz'

# Load and convert the selected checkpoint.
params = checkpoint.load(ckpt_path)
params['pre_logits'] = {}  # Need to restore empty leaf for Flax.

In [10]:
# Get imagenet labels: maps line index -> label line (trailing newline kept).
# Download the file once with:
# !wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt
# The `with` block closes the file handle instead of leaking it.
with open('ilsvrc2012_wordnet_lemmas.txt') as labels_file:
  imagenet_labels = dict(enumerate(labels_file))

In [11]:
# ImageNet-1k class indices that make up each of the 16 stimulus categories.
# An ImageNet class contributes to a category when its index appears in that
# category's list below.
#
# `category_indices` is consumed positionally by the prediction cell, which
# pairs entry i with the i-th name of the alphabetically sorted stimulus
# folders. Keep this list in alphabetical category order.
airplane_indices = [404]
bear_indices = [294, 295, 296, 297]
bicycle_indices = [444, 671]
bird_indices = [8, 10, 11, 12, 13, 14, 15, 16, 18, 19, 20, 22, 23,
                24, 80, 81, 82, 83, 87, 88, 89, 90, 91, 92, 93,
                94, 95, 96, 98, 99, 100, 127, 128, 129, 130, 131,
                132, 133, 135, 136, 137, 138, 139, 140, 141, 142,
                143, 144, 145]
boat_indices = [472, 554, 625, 814, 914]
bottle_indices = [440, 720, 737, 898, 899, 901, 907]
car_indices = [436, 511, 817]
cat_indices = [281, 282, 283, 284, 285, 286]
chair_indices = [423, 559, 765, 857]
clock_indices = [409, 530, 892]
dog_indices = [152, 153, 154, 155, 156, 157, 158, 159, 160, 161,
                162, 163, 164, 165, 166, 167, 168, 169, 170, 171,
                172, 173, 174, 175, 176, 177, 178, 179, 180, 181,
                182, 183, 184, 185, 186, 187, 188, 189, 190, 191,
                193, 194, 195, 196, 197, 198, 199, 200, 201, 202,
                203, 205, 206, 207, 208, 209, 210, 211, 212, 213,
                214, 215, 216, 217, 218, 219, 220, 221, 222, 223,
                224, 225, 226, 228, 229, 230, 231, 232, 233, 234,
                235, 236, 237, 238, 239, 240, 241, 243, 244, 245,
                246, 247, 248, 249, 250, 252, 253, 254, 255, 256,
                257, 259, 261, 262, 263, 265, 266, 267, 268]
elephant_indices = [385, 386] 
keyboard_indices = [508, 878]
knife_indices = [499]
oven_indices = [766]
truck_indices = [555, 569, 656, 675, 717, 734, 864, 867]

# Alphabetical order: airplane, bear, bicycle, bird, boat, bottle, car, cat,
# chair, clock, dog, elephant, keyboard, knife, oven, truck.
category_indices = [airplane_indices, bear_indices, bicycle_indices, bird_indices, boat_indices,
                    bottle_indices, car_indices, cat_indices, chair_indices, clock_indices,
                    dog_indices, elephant_indices, keyboard_indices, knife_indices,
                    oven_indices, truck_indices]

In [16]:
import cv2
import csv

exp = 'cue-conflict'
stimuli_dir = os.path.join('./stimuli', exp)

# Category folder names in alphabetical order. Position i must line up with
# `category_indices[i]`, so stray non-directory entries (e.g. .DS_Store) are
# ignored; otherwise they would shift every predicted label.
categories = sorted(
    entry for entry in os.listdir(stimuli_dir)
    if os.path.isdir(os.path.join(stimuli_dir, entry)))
assert len(categories) == len(category_indices), (
    f'Found {len(categories)} category folders in {stimuli_dir}, '
    f'expected {len(category_indices)}.')

# Side length (pixels) of the square images fed to the model.
IMAGE_SIZE = 384

# De-duplicated ImageNet class indices per category, as integer arrays, so the
# per-category scores are a vectorised gather instead of a 1000x16 Python loop.
category_index_arrays = [np.array(sorted(set(ci))) for ci in category_indices]


def to_16_class_scores(probs):
  """Aggregate 1000-way ImageNet probabilities into one score per category.

  Args:
    probs: 1-D array of the 1000 ImageNet class probabilities.

  Returns:
    float64 np.ndarray with one entry per category; entry i is the SUM of the
    probabilities of the ImageNet classes listed in category_indices[i].

  NOTE(review): reference shape-bias evaluation code reportedly averages
  (np.mean) rather than sums. Summing favours categories spanning many
  ImageNet classes (dog, bird). Confirm which aggregation is intended.
  """
  probs = np.asarray(probs, dtype=np.float64)
  return np.array([probs[idx].sum() for idx in category_index_arrays])


obj_response = []
obj_category = []
image_name = []

print(f'Prediction for model: "{model}" on experiment: "{exp}"')

# Every (category, file) pair, sorted for a reproducible trial order. Counting
# them up front keeps the progress display exact even when categories hold
# different numbers of images.
stimuli = [(c, im)
           for c in categories
           for im in sorted(os.listdir(os.path.join(stimuli_dir, c)))]

for count, (c, im) in enumerate(stimuli, start=1):
  img = cv2.imread(os.path.join(stimuli_dir, c, im))
  if img is None:
    # cv2.imread returns None (instead of raising) for unreadable files.
    print(f'\nSkipping unreadable file: {os.path.join(c, im)}')
    continue
  # OpenCV decodes images in BGR channel order; convert to RGB for the model.
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
  # Map uint8 pixels [0, 255] to [-1, ~0.99] and add a leading batch axis.
  inp = (np.array(img) / 128 - 1)[None, ...]
  # Predict on a batch with a single item (not an efficient use of a TPU).
  logits, = VisionTransformer.call(params, inp)
  preds_16 = to_16_class_scores(flax.nn.softmax(logits))

  # Append only after a successful prediction so the three lists stay aligned.
  image_name.append(im)
  obj_category.append(c)
  obj_response.append(categories[preds_16.argmax()])

  print('\r %0.2f%%' % (count / len(stimuli) * 100), end='')

Prediction for model: "ViT-B_32" on experiment: "cue-conflict"
 100.00%

In [17]:
# Write one row per stimulus. Reaction time ('rt') and 'condition' do not apply
# to a model and are filled with 'NaN'.
assert len(obj_response) == len(obj_category) == len(image_name), (
    'Prediction lists are misaligned; re-run the prediction cell.')

results_dir = f'./results/fine-tune/imagenet2012/bias-tests/texture-shape_{exp}'
os.makedirs(results_dir, exist_ok=True)  # open(..., 'w') fails if it is missing
results_path = os.path.join(
    results_dir, f'texture-shape_{exp}_{model}_ft_session-1.csv')

# newline='' lets the csv module control line endings (no blank rows on Windows).
with open(results_path, mode='w', newline='') as csv_file:
  csv_writer = csv.writer(csv_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)

  csv_writer.writerow(['subj', 'session', 'trial', 'rt', 'object_response', 'category', 'condition', 'imagename'])
  rows = zip(obj_response, obj_category, image_name)
  for trial, (response, category, name) in enumerate(rows, start=1):
    csv_writer.writerow([f'{model}_ft', '1', f'{trial}', 'NaN', response, category, 'NaN', name])