```
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

 https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
```

# Aggregating Nested Transformer

https://arxiv.org/pdf/2105.12723.pdf

This colab shows how to:

- inspect the CIFAR input pipeline
- load a pretrained ImageNet checkpoint for inference
- train on CIFAR for a few steps

## Setup

In [None]:
![ -d nested-transformer ] || git clone --depth=1 https://github.com/google-research/nested-transformer
!cd nested-transformer && git pull

In [None]:
!pip install -qr nested-transformer/requirements.txt

### (Optional) Connect to TPU

In Colab, select Runtime → Change runtime type → TPU, then set `USE_TPU = True` in the cell below.

In [None]:
USE_TPU = False

if USE_TPU:
  # Colab "TPU" runtimes use a "2VM" setup: the TPUs sit on a second machine
  # rather than being attached to this VM, so JAX has to be pointed at them
  # explicitly before it can see them.
  # NOTE(review): `jax.tools.colab_tpu` may not exist in newer JAX releases —
  # confirm against the version pinned in requirements.txt.
  import os
  on_colab_tpu = ('google.colab' in str(get_ipython())
                  and 'COLAB_TPU_ADDR' in os.environ)
  if not on_colab_tpu:
    print('No TPU detected. Can be changed under "Runtime/Change runtime type".')
  else:
    import jax
    import jax.tools.colab_tpu
    jax.tools.colab_tpu.setup_tpu()
    print('Connected to TPU.')

## Import

In [2]:
import functools
import os
import sys
import time

# Make the cloned repository's packages (libml, models, configs, train)
# importable.
sys.path.append('./nested-transformer')

from absl import logging
import flax
# NOTE(review): `flax.nn` is the legacy pre-Linen API and is absent from newer
# Flax releases; it only imports with the pinned Flax version — confirm.
from flax import nn
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

from configs import cifar_nest
from configs import imagenet_nest
from libml import input_pipeline
from libml import preprocess
from models import nest_net
import train

# Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
# it unavailable to JAX.
tf.config.set_visible_devices([], "GPU")
logging.set_verbosity(logging.INFO)

print("JAX devices:\n" + "\n".join([repr(d) for d in jax.devices()]))
print('Current folder content', os.listdir())

2023-09-28 19:32:18.953627: E tensorflow/stream_executor/cuda/cuda_driver.cc:328] failed call to cuInit: CUDA_ERROR_COMPAT_NOT_SUPPORTED_ON_DEVICE: forward compatibility was attempted on non supported HW
2023-09-28 19:32:18.953796: E tensorflow/stream_executor/cuda/cuda_diagnostics.cc:313] kernel version 535.86.10 does not match DSO version 535.104.5 -- cannot find working devices in this configuration
INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
INFO:absl:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:absl:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.


JAX devices:
CpuDevice(id=0)
Current folder content ['main.py', 'train.py', 'colab.ipynb', 'requirements.txt', 'one_image_inference.py', 'configs', '__pycache__', 'checkpoints', 'README.md', 'models', 'CONTRIBUTING.md', 'augment', 'libml', 'LICENSE']


## ImageNet

### Download checkpoints and load model

In [1]:
checkpoint_dir = "./nested-transformer/checkpoints/"
remote_checkpoint_dir = "gs://gresearch/nest-checkpoints/nest-b_imagenet"
print('List checkpoints: ')
!gsutil ls "$remote_checkpoint_dir"

List checkpoints: 
/bin/bash: line 1: gsutil: command not found


In [None]:
print('Download checkpoints: ')
!mkdir -p "$checkpoint_dir"
!gsutil cp -r "$remote_checkpoint_dir" "$checkpoint_dir".

In [None]:
# Use checkpoint of host 0.
# NOTE(review): presumably `load_state_dict` reads the host-0 copy saved during
# training — confirm.
imagenet_config = imagenet_nest.get_config()

local_checkpoint_path = os.path.join(
    checkpoint_dir, os.path.basename(remote_checkpoint_dir))
state_dict = train.checkpoint.load_state_dict(local_checkpoint_path)

# Model variables: the trained parameters (the optimizer's target) plus every
# entry stored under `model_state` (presumably non-parameter collections such
# as batch statistics — confirm). Later keys win, exactly as with dict.update.
variables = {
    "params": state_dict["optimizer"]["target"],
    **state_dict["model_state"],
}

model_cls = nest_net.create_model(imagenet_config.model_name, imagenet_config)
# Bind the 1000-way ImageNet head; callers still pass e.g. `train=...`.
model = functools.partial(model_cls, num_classes=1000)

### Inference on a single image

In [None]:
import PIL

!wget https://picsum.photos/id/237/200/300 -O dog.jpg
img = PIL.Image.open('dog.jpg')
img

In [None]:
def predict(image):
  """Classifies a batch of preprocessed images with the ImageNet model.

  Uses the notebook-level `model` and `variables` loaded from the checkpoint.

  Args:
    image: A batch as returned by `_preprocess`.

  Returns:
    A `(class_ids, probs)` tuple: the arg-max class of each example and the
    softmax probability assigned to that class.
  """
  logits = model(train=False).apply(variables, image, mutable=False)
  # `jax.nn.softmax` rather than the legacy `flax.nn` module, which newer Flax
  # releases no longer ship.
  return logits.argmax(axis=-1), jax.nn.softmax(logits, axis=-1).max(axis=-1)


def _preprocess(image, size=224):
  """Turns a PIL image into a normalized batch containing that one image.

  Args:
    image: A PIL image. It is converted to RGB first so that grayscale, RGBA or
      palette images still produce the three channels the mean/std expect.
    size: Side length to resize to. The aspect ratio is not preserved.

  Returns:
    A float32 array of shape `(1, size, size, 3)`: pixels scaled to [0, 1] and
    then normalized with the ImageNet mean/std from `preprocess`.
  """
  pixels = np.asarray(
      image.convert("RGB").resize((size, size)), dtype=np.float32) / 255
  # float32 constants keep the result float32 instead of promoting to float64.
  mean = np.asarray(
      preprocess.IMAGENET_DEFAULT_MEAN, dtype=np.float32).reshape(1, 1, 3)
  std = np.asarray(
      preprocess.IMAGENET_DEFAULT_STD, dtype=np.float32).reshape(1, 1, 3)
  normalized = (pixels - mean) / std
  return normalized[np.newaxis, ...]


# `inputs` rather than `input`, to avoid shadowing the built-in.
inputs = _preprocess(img)

cls, prob = predict(inputs)
print(f'ImageNet class id: {cls[0]}, prob: {prob[0]}')

## CIFAR

### Inspect input pipeline of CIFAR

Visualize a few training examples with the CIFAR image augmentations applied.

In [None]:
# Only the dataset metadata (`cifar_builder.info`) is needed, to label the
# visualized examples.
cifar_builder = tfds.builder(name="cifar10")

In [None]:
# Dedicated names for this cell, so the `mix` override and these datasets
# cannot leak into the training cells, which use `config` / `train_ds`.
vis_config = cifar_nest.get_config()
# Do not apply MixUp or CutMix operations since tfds.visualization.show_examples
# only accepts integer labels.
vis_config.mix = None

_, vis_train_ds, _ = input_pipeline.create_datasets(
    vis_config, jax.random.PRNGKey(0)
)
# Unbatch twice to reach individual examples. NOTE(review): presumably the
# train dataset has two leading batch axes (devices, per-device batch) — confirm.
tfds.visualization.show_examples(vis_train_ds.unbatch().unbatch(), cifar_builder.info);

### Running a single training step on CIFAR


In [None]:
config = cifar_nest.get_config()
# Shrink the schedule to a smoke test: one train step, one eval step, one
# epoch, and no warmup epochs.
config.num_train_steps, config.num_eval_steps = 1, 1
config.num_epochs, config.warmup_epochs = 1, 0
# Smaller per-device batch size to avoid running out of memory.
config.per_device_batch_size = 128
# Timestamped so every run writes to a fresh directory.
workdir = f"./nested-transformer/checkpoints/cifar_nest_colab_{int(time.time())}"

In [None]:
# Rebuild the datasets from the training `config` defined above, so overrides
# such as `per_device_batch_size` are reflected.
dataset_rng = jax.random.PRNGKey(0)
info, train_ds, eval_ds = input_pipeline.create_datasets(config, dataset_rng)

In [None]:
# Train and evaluate with the shortened `config` from above, writing to the
# timestamped `workdir` (NOTE(review): presumably checkpoints and summaries —
# confirm in train.py).
train.train_and_evaluate(config, workdir)