### Intro to JAX
[JAX](https://github.com/google/jax) is a framework for high-performance numerical computing and machine learning research, developed by teams at [Google Research](https://research.google/). It offers a NumPy-consistent API for Python programs and specializes in differentiating, vectorizing, parallelizing, and Just-In-Time compiling code for GPUs and TPUs. JAX was designed with speed and performance as a first priority, and it is natively compatible with common machine learning accelerators such as [GPUs](https://www.kaggle.com/docs/efficient-gpu-usage) and [TPUs](https://www.kaggle.com/docs/tpu). Large ML models can take a long time to train, so JAX is worth considering for applications where speed and performance are particularly important.
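
As a rough illustration of the transformations mentioned above, here is a minimal sketch that combines `jax.grad`, `jax.vmap`, and `jax.jit`. The function `squared_error` and the toy inputs are hypothetical and are not used elsewhere in this notebook.

```python
import jax
import jax.numpy as jnp

def squared_error(w, x, y):
    """Squared error of a toy linear model (hypothetical example function)."""
    return (jnp.dot(x, w) - y) ** 2

# grad: differentiate with respect to the first argument (w)
grad_fn = jax.grad(squared_error)
# vmap: apply grad_fn to a batch of (x, y) pairs while sharing the same w
batched_grad = jax.vmap(grad_fn, in_axes=(None, 0, 0))
# jit: compile the batched gradient function with XLA
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([1.0, 1.0, 1.0])

# One gradient per example: the result has shape (3, 2)
grads = fast_batched_grad(w, xs, ys)
```

Each transformation takes a function and returns a new function, which is why they can be stacked in this way.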
### When to use JAX vs TensorFlow?
[TensorFlow](https://www.tensorflow.org/guide) is a fantastic product, with a rich and fully-featured ecosystem, capable of supporting almost every use case a machine learning practitioner might have (e.g. [TFLite](https://www.tensorflow.org/lite) for on-device inference, [TFHub](https://tfhub.dev/) for sharing pre-trained models, and many additional specialized applications). This broad mandate both contrasts with and complements JAX's philosophy, which is more narrowly focused on speed and performance. We recommend using JAX when you want to maximize speed and performance and do not require any of the long tail of features and additional functionality that only the [TensorFlow ecosystem](https://www.tensorflow.org/learn) can provide.
### Intro to Flax
Just as [JAX](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html) focuses on speed, other members of the JAX ecosystem are encouraged to specialize as well. For example, [Flax](https://flax.readthedocs.io/en/latest/) focuses on neural networks and [Jraph](https://github.com/deepmind/jraph) focuses on graph networks.

[Flax](https://flax.readthedocs.io/en/latest/) is a JAX-based neural network library that was initially developed by Google Research's Brain Team (in close collaboration with the JAX team) and is now open source. If you want to train machine learning models quickly on GPUs and TPUs, or if you have an ML project that might benefit from bringing together both [Autograd](https://github.com/hips/autograd) and [XLA](https://www.tensorflow.org/xla), consider using Flax for your next project. Flax is especially well suited to projects that use large language models, and it is a popular choice for cutting-edge [machine learning research](https://arxiv.org/search/?query=JAX&searchtype=all&abstracts=show&order=-announced_date_first&size=50).

### Disclaimer:
**We recommend using [GPUs](https://www.kaggle.com/docs/efficient-gpu-usage) when working with JAX on Kaggle.** These notebooks are compatible with the v3-8 [TPUs](https://www.kaggle.com/docs/tpu) that are provided for free in [Kaggle Notebooks](https://www.kaggle.com/code/new), but JAX was optimized for the newly updated [TPU VM](https://cloud.google.com/blog/products/compute/introducing-cloud-tpu-vms) architecture which is not yet available on Kaggle.


## Imports

In [None]:
# Uncomment and run the lines below only when the accelerator is a TPU
#%%capture
#!conda install -y -c conda-forge jax jaxlib flax optax
#!conda install -y importlib-metadata

In [None]:
# Import all the libraries needed for this notebook
import os
# Suppress warnings caused by the CUDA version (set before TensorFlow is imported)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import functools
import random
import time
from random import randint
from typing import Any

import cv2
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
import torch
import torch.nn.functional as F
import torchvision
from PIL import Image
from scipy import ndimage
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision import transforms
from tqdm import tqdm

import jax
import jax.nn
import jax.numpy as jnp
import jax.random
import optax
import flax.linen as nn  # `nn` refers to Flax Linen throughout this notebook
from flax.training import train_state

### TPU detection and configuration


In [None]:
# Kaggle exposes the TPU address through the TPU_NAME environment variable
if 'TPU_NAME' in os.environ:
    import requests
    if 'TPU_DRIVER_MODE' not in globals():
        # Ask the TPU host (port 8475) to load the nightly TPU driver; do this only once per session
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        resp = requests.post(url)
        TPU_DRIVER_MODE = 1
    # Point JAX's XLA backend at the remote TPU driver
    from jax.config import config
    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    print('Registered TPU:', config.FLAGS.jax_backend_target)
else:
    print('No TPU detected.')

In [None]:
# Shows the list of the available devices
jax.local_devices()

## Data Reading and Processing
We are using the image dataset from the well-known Kaggle competition [Dog Breed Identification](https://www.kaggle.com/c/dog-breed-identification). The dataset covers 120 dog breeds.


In [None]:
DATA_DIR = '../input/dog-breed-identification'
TRAIN_DIR = DATA_DIR + '/train'                           
TRAIN_CSV = DATA_DIR + '/labels.csv'     
data_df = pd.read_csv(TRAIN_CSV)

In [None]:
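# Sorted array of the unique breed names (np.sort returns a new sorted array,
# whereas ndarray.sort() sorts in place and returns None)
labels_names = np.sort(data_df["breed"].unique())
# Map integer class index -> breed name
labels = dict(zip(range(len(labels_names)), labels_names))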

In [None]:
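lbl=[]
path_img=[]

# Reverse lookup (breed name -> integer class index), built once instead of
# searching the list of label values for every row
name_to_index = {name: idx for idx, name in labels.items()}

for i in range(len(data_df["breed"])):
    lbl.append(name_to_index[data_df.breed[i]])
    path_img.append(TRAIN_DIR + "/" + str(data_df.id[i]) + ".jpg")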

data_df['path_img'] =path_img  
data_df['lbl'] = lbl

data_df.head()
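
The label encoding above can also be expressed with NumPy alone. The following is a minimal sketch on a hypothetical miniature `breeds` array (not the real `labels.csv`), using `np.unique` with `return_inverse=True` to get the sorted class names and the integer label of every row in one call.

```python
import numpy as np

# Hypothetical miniature version of the `breed` column
breeds = np.array(["pug", "beagle", "pug", "collie", "beagle"])

# label_names: sorted unique breed names
# label_ids:   for every row, the index of its breed in label_names
label_names, label_ids = np.unique(breeds, return_inverse=True)

# Integer class index -> breed name, like the `labels` dict above
id_to_name = dict(enumerate(label_names))
```

Indexing `label_names[label_ids]` recovers the original column, which is a quick way to check the encoding.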

Next, we create a custom dataset class using PyTorch's [Datasets and Dataloaders](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#loading-a-dataset).

In [None]:
class DogDataset(Dataset):
    def __init__(self, df, root_dir, transform=None):
        self.df = df
        self.transform = transform
        self.root_dir = root_dir
        
    def __len__(self):
        return len(self.df)    
    
    def __getitem__(self, idx):
        row = self.df.loc[idx]
        img_id, img_label = row['id'], row['lbl']
        img_fname = self.root_dir + "/" + str(img_id) + ".jpg"
        img = Image.open(img_fname)
        if self.transform:
            img = self.transform(img)
        return img, img_label

Next, we apply data augmentations to the dataset using PyTorch's [transforms](https://pytorch.org/vision/stable/transforms.html) and return the images as NumPy arrays.

In [None]:
IMAGE_HEIGHT = 128
IMAGE_WIDTH = 128

training_transform = transforms.Compose([
    transforms.RandomAffine(degrees=(-30, 30),
                            translate=(0.0, 0.2)),
    transforms.RandomHorizontalFlip(),
    transforms.Resize((IMAGE_HEIGHT,
                        IMAGE_WIDTH)),
    np.array])

testing_transform = transforms.Compose([
    transforms.Resize((IMAGE_HEIGHT,
                       IMAGE_WIDTH)),
    np.array])


np.random.seed(42)
msk = np.random.rand(len(data_df)) < 0.8

train_df = data_df[msk].reset_index()
val_df = data_df[~msk].reset_index()

train_ds = DogDataset(train_df, TRAIN_DIR, transform=training_transform)
val_ds = DogDataset(val_df, TRAIN_DIR, transform=testing_transform)
len(train_ds), len(val_ds)
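
Because the mask is drawn independently for every row, the split is only approximately 80/20. The sketch below illustrates this on a hypothetical row count; it uses NumPy's `default_rng` generator, which is an assumption of this example rather than what the cell above uses.

```python
import numpy as np

rng = np.random.default_rng(42)   # assumption: Generator API instead of np.random.seed
n_rows = 10_000                   # hypothetical dataset size
mask = rng.random(n_rows) < 0.8   # True with probability 0.8 for each row

train_idx = np.flatnonzero(mask)   # rows that go to the training set
val_idx = np.flatnonzero(~mask)    # remaining rows go to the validation set
train_fraction = len(train_idx) / n_rows
```

Every row lands in exactly one of the two sets, and `train_fraction` is close to, but generally not exactly, 0.8.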

The Dataset retrieves our dataset’s features and labels one sample at a time. While training a model, we typically want to pass samples in “minibatches”, reshuffle the data at every epoch to reduce model overfitting, and use Python’s multiprocessing to speed up data retrieval.

DataLoader is an iterable that abstracts this complexity for us in an easy API.

-Source: https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#loading-a-dataset
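
To make the quoted description concrete, here is a simplified NumPy sketch of what shuffled minibatching with `drop_last=True` amounts to. The helper `iterate_minibatches` is hypothetical and far less capable than PyTorch's `DataLoader` (no multiprocessing, no collation).

```python
import numpy as np

def iterate_minibatches(images, labels, batch_size, rng, drop_last=True):
    """Yield shuffled (images, labels) minibatches; hypothetical helper."""
    order = rng.permutation(len(images))          # new order on every call
    remainder = len(images) % batch_size
    end = len(images) - (remainder if drop_last else 0)
    for start in range(0, end, batch_size):
        idx = order[start:start + batch_size]
        yield images[idx], labels[idx]

rng = np.random.default_rng(0)
images = np.zeros((10, 4, 4, 3), dtype=np.uint8)   # 10 tiny dummy images
labels = np.arange(10)

# 10 samples with batch_size=4 and drop_last=True give 2 full batches
batches = list(iterate_minibatches(images, labels, batch_size=4, rng=rng))
```

Dropping the last partial batch matters later in this notebook, because every batch must divide evenly across the devices.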

In [None]:
import jax
# Number of accelerator devices visible to JAX; each batch is split evenly across them
NUM_TPUS = jax.device_count()
BATCH_SIZE = 128
train_dataloader = DataLoader(train_ds,
                              batch_size=BATCH_SIZE,
                              shuffle=True, drop_last=True,
                              num_workers=0)
test_dataloader = DataLoader(val_ds,
                             batch_size=BATCH_SIZE,
                             shuffle=True, drop_last=True,
                             num_workers=0)

In [None]:
(image_batch, label_batch) = next(iter(train_dataloader))
print(image_batch.shape)
print(label_batch.shape)

## Creating batches/shards of the data

Here we read batches of data from the CPU's RAM and copy them to the memory of the accelerators used for computation. Each batch is split into one shard per device and placed with JAX's [`device_put_sharded`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put_sharded.html#jax.device_put_sharded), which returns ShardedDeviceArray(s).

In [None]:
def copy_dataset_to_devices(dataset, devices, num_reps=1):
    sharded_images = []
    sharded_labels = []
    for _ in range(num_reps):
        for image_batch, label_batch in tqdm(dataset, ncols=100):
            image_batch = image_batch.detach().cpu().numpy()
            image_batches = np.split(image_batch, NUM_TPUS, axis = 0)
            sharded_device_images = jax.device_put_sharded(image_batches, devices)
            sharded_images.append(sharded_device_images)

            label_batch = label_batch.detach().cpu().numpy()
            label_batches = np.split(label_batch, NUM_TPUS, axis = 0)
            sharded_device_labels = jax.device_put_sharded(label_batches, devices)
            sharded_labels.append(sharded_device_labels)

    return sharded_images, sharded_labels

devices = jax.local_devices()
sharded_training_images, sharded_training_labels = copy_dataset_to_devices(train_dataloader, devices, num_reps=10)
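
The shape bookkeeping behind the sharding step can be checked without any accelerator. The sketch below uses NumPy only and assumes 8 devices (the number of cores on a v3-8 TPU); `num_devices` and the dummy arrays are illustrative.

```python
import numpy as np

num_devices = 8      # assumption: a v3-8 TPU exposes 8 cores
batch_size = 128
image_batch = np.zeros((batch_size, 128, 128, 3), dtype=np.uint8)
label_batch = np.zeros((batch_size,), dtype=np.int64)

# np.split raises an error unless batch_size is divisible by num_devices
image_shards = np.split(image_batch, num_devices, axis=0)
label_shards = np.split(label_batch, num_devices, axis=0)

# Stacking the shards shows the leading device axis that a pmapped function sees
stacked_images = np.stack(image_shards)
stacked_labels = np.stack(label_shards)
```

With these numbers each device receives 16 images per step, so the per-device batch size is `BATCH_SIZE / NUM_TPUS`.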

## Model architecture

In this notebook we use the VGG19 network, and we define its architecture from scratch with the [Flax Linen package](https://flax.readthedocs.io/en/latest/flax.linen.html).

In [None]:
NUM_CLASSES = 120 
class VGG19(nn.Module):
    @nn.compact
    def __call__(self, x, training):
        x = self._stack(x, 64, training)
        x = self._stack(x, 64, training)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
    
        x = self._stack(x, 128, training)
        x = self._stack(x, 128, training)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

        x = self._stack(x, 256, training)
        x = self._stack(x, 256, training)
        x = self._stack(x, 256, training)
        x = self._stack(x, 256, training)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))    

        x = self._stack(x, 512, training)
        x = self._stack(x, 512, training)
        x = self._stack(x, 512, training)
        x = self._stack(x, 512, training)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))    
    
        x = self._stack(x, 512, training)
        x = self._stack(x, 512, training)
        x = self._stack(x, 512, training)
        x = self._stack(x, 512, training)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))  

        x = x.reshape((x.shape[0], -1))

        x = nn.Dense(features=4096)(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        x = nn.Dropout(0.5, deterministic=not training)(x)

        x = nn.Dense(features=4096)(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        x = nn.Dropout(0.5, deterministic=not training)(x)
    
        x = nn.Dense(features=NUM_CLASSES)(x)
        x = nn.log_softmax(x)
        return x
  
    @staticmethod
    def _stack(x, features, training, dropout=None):
        x = nn.Conv(features=features, kernel_size=(3, 3), padding='SAME')(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        return x
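
It can help to work out how many features reach the first `Dense` layer. The sketch below does the arithmetic for the architecture above, assuming `'SAME'` convolutions (which keep the spatial size) and five 2x2 max-pools with stride 2 (each halves it); `flattened_features` is a hypothetical helper.

```python
def flattened_features(height, width, channels=512, num_pools=5):
    """Number of features after the final max-pool, before the first Dense layer."""
    for _ in range(num_pools):
        height //= 2   # each 2x2, stride-2 pool halves the height
        width //= 2    # ... and the width
    return height * width * channels

# 128x128 input -> 4x4 spatial map with 512 channels
n_features = flattened_features(128, 128)
# Weight count of the first Dense(4096) layer, excluding the bias
dense1_weights = n_features * 4096
```

For the 128x128 inputs used here this gives 8192 flattened features, so the first dense layer alone holds roughly 33.5 million weights.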

## Train function
The train function loops over the sharded training images and labels, one batch at a time, and trains the network starting from the given state. It returns the final training state together with the training accuracy recorded for each epoch.

In [None]:
def average_metrics(metrics):
    '''
    Takes the list of dictionaries of the form k: v, and returns a dictionary
     of the form k: (average of the v).
    '''
    return {k: np.mean([metric[k] for metric in metrics])
        for k in metrics[0]}

def train(initial_network_state, num_epochs):
    '''
    Training the model from the given state, returns the state along with the training accuracies
    '''
    training_accuracies = []
    state = initial_network_state
    for i in range(num_epochs):
        batch_metrics = []
        for (image_batch, label_batch) in tqdm(zip(sharded_training_images,
                                               sharded_training_labels),
                                           total=len(sharded_training_images),
                                           ncols=100):
            state, metrics = train_batch(state, image_batch, label_batch)
            batch_metrics.append(metrics)
        train_metrics = average_metrics(batch_metrics)
        print(f'Epoch {i+1} done.', flush=True)
        print(f'  Loss: {train_metrics["loss"]:.4f}, '
          + f'accuracy: {train_metrics["accuracy"]:.4f}', flush=True)
        training_accuracies.append(train_metrics["accuracy"])
    return state, training_accuracies
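
As a small worked example of the averaging step, the sketch below applies the same `average_metrics` logic to two hypothetical batches whose metrics are replicated across two devices. The numbers are made up for illustration.

```python
import numpy as np

def average_metrics(metrics):
    """Average every key over a list of per-batch metric dictionaries."""
    return {k: np.mean([m[k] for m in metrics]) for k in metrics[0]}

# Each value has one entry per device, as returned by a pmapped function
batch_metrics = [
    {"loss": np.array([2.0, 2.0]), "accuracy": np.array([0.25, 0.25])},
    {"loss": np.array([1.0, 1.0]), "accuracy": np.array([0.75, 0.75])},
]

# np.mean averages over both the batch axis and the device axis
epoch_metrics = average_metrics(batch_metrics)
```

Here `epoch_metrics` holds a loss of 1.5 and an accuracy of 0.5.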

### Model Initialization functions

In Flax, we have to create and update the train state ourselves. The state holds all of the model's variables and is built on [`flax.training.train_state`](https://flax.readthedocs.io/en/latest/_modules/flax/training/train_state.html#TrainState).

In [None]:
class VGGState(train_state.TrainState):
    # Extra fields carried alongside the standard TrainState fields
    rng: Any
    batch_stats: Any
  
    @classmethod
    def create(cls, apply_fn, params, tx, rng, batch_stats):
        # Initialize the optimizer state once, when the train state is first built
        opt_state = tx.init(params)
        state = cls(step=0, apply_fn=apply_fn, params=params, tx=tx,
                    opt_state=opt_state, rng=rng, batch_stats=batch_stats)
        return state
  
    @classmethod
    def update_rng(cls, state, rng):
        # replace() copies the state and swaps one field, so the step count and
        # optimizer state are kept (calling create() here would reset them)
        return state.replace(rng=rng)
  
    @classmethod
    def update_batch_stats(cls, state, batch_stats):
        return state.replace(batch_stats=batch_stats)
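
Flax train states are immutable dataclasses, so an "update" always means building a new object. The standard-library sketch below illustrates that pattern with a hypothetical `ToyState`; it is a stand-in, not the Flax class, and only shows how replacing one field leaves the others untouched.

```python
import dataclasses
from typing import Any

@dataclasses.dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in for a train state; not the Flax class."""
    step: int
    opt_state: Any
    rng: Any

state = ToyState(step=10, opt_state={"momentum": 0.3}, rng=123)

# Copy the state with a new rng; step and opt_state carry over unchanged
new_state = dataclasses.replace(state, rng=456)
```

The original `state` is not modified, which is what makes this style safe to use inside transformed functions such as `jax.pmap`.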

## Loss & Metrics calculations
Now we define the functions that calculate the training loss and the accuracy from the predicted values and the labels.

In [None]:
def accuracy(logits, labels):
    '''
    Calculates the accuracy using the given logits and labels
    '''
    return jnp.mean(jnp.argmax(logits, -1) == labels)

def cross_entropy(logits, labels):
    '''
    Cross Entropy error between the logits and labels
    '''
    one_hot_labels = jax.nn.one_hot(labels, NUM_CLASSES)
    cross_entropy = optax.softmax_cross_entropy(logits, one_hot_labels)
    return jnp.mean(cross_entropy)

def training_loss(image_batch, label_batch, rng, batch_stats, params):
    '''
    Calculates the training loss 
    '''
    logits, batch_stats = VGG19().apply({'params': params,
                                       'batch_stats': batch_stats},
                                      image_batch, 
                                      training=True,
                                      rngs={'dropout': rng},
                                      mutable=['batch_stats'])
    loss = cross_entropy(logits, label_batch)
    return loss, (logits, batch_stats)
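
The cross-entropy computation can be written out by hand as a sanity check. This is a minimal sketch with a hypothetical helper, `cross_entropy_from_logits`, and toy inputs; it also shows that applying `log_softmax` twice gives the same result as applying it once, which is why passing the model's log-probabilities to a softmax cross-entropy does not change the loss.

```python
import jax
import jax.numpy as jnp

def cross_entropy_from_logits(logits, labels, num_classes):
    """Mean cross-entropy between integer labels and unnormalized logits."""
    log_probs = jax.nn.log_softmax(logits)
    one_hot = jax.nn.one_hot(labels, num_classes)
    return jnp.mean(-jnp.sum(one_hot * log_probs, axis=-1))

num_classes = 120
logits = jnp.zeros((4, num_classes))    # a model that predicts every class equally
labels = jnp.array([0, 5, 17, 119])
uniform_loss = cross_entropy_from_logits(logits, labels, num_classes)

# log_softmax is idempotent: normalizing log-probabilities again changes nothing
log_probs = jax.nn.log_softmax(jnp.array([[1.0, 2.0, 3.0]]))
log_probs_twice = jax.nn.log_softmax(log_probs)
```

For uniform predictions the loss equals `log(120)`, roughly 4.79, which is a useful reference value for the first training steps.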

## Training a single batch function
We will now define the function for training a single batch of data, which will take the current train state and the training data as input and return the updated train state along with the training statistics.

In [None]:
@functools.partial(jax.pmap, axis_name='tpu')
def train_batch(state, image_batch, label_batch):
    rng, subrng = jax.random.split(state.rng)
    batch_loss_fn = functools.partial(training_loss, image_batch, label_batch,
                                    subrng, state.batch_stats)
    (batch_loss, (logits, batch_stats)), grads = \
    jax.value_and_grad(batch_loss_fn, has_aux=True)(state.params)
  
    gradsum = jax.lax.psum(grads, axis_name='tpu')
    
    state = state.apply_gradients(grads=gradsum)
    state = state.update_batch_stats(state, batch_stats['batch_stats'])
    state = state.update_rng(state, rng)

    batch_accuracy = accuracy(logits=logits, labels=label_batch)
    # Average both metrics across devices so they are reported on the same scale
    batch_accuracy = jax.lax.pmean(batch_accuracy, axis_name='tpu')
    batch_loss = jax.lax.pmean(batch_loss, axis_name='tpu')
    stats = {'loss': batch_loss, 'accuracy': batch_accuracy}
    return state, stats
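
The collectives used above can be tried in isolation. The sketch below is a minimal example, assuming nothing beyond JAX itself: it runs on however many local devices are available (including a single CPU device) and compares `jax.lax.pmean` with `jax.lax.psum`. The function name and the axis name `'devices'` are illustrative.

```python
import functools

import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

@functools.partial(jax.pmap, axis_name='devices')
def mean_and_sum(x):
    local_mean = jnp.mean(x)   # computed independently on each device
    # pmean averages over the named axis, psum adds over it
    global_mean = jax.lax.pmean(local_mean, axis_name='devices')
    global_sum = jax.lax.psum(local_mean, axis_name='devices')
    return global_mean, global_sum

# The leading axis must match the number of devices: one row per device
x = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape(n_devices, 4)
global_mean, global_sum = mean_and_sum(x)
```

Every device receives the same reduced values, and the summed result is the averaged result multiplied by the number of devices.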

## Creating train state
Here we create the initial train state that is passed to the training loop.

In [None]:
def create_train_state(rng, dummy_image_batch):
    net = VGG19()
    params = net.init({'params': rng, 'dropout': rng}, dummy_image_batch, True)
    tx = optax.adam(learning_rate=0.001)
    state = VGGState.create(net.apply, params['params'], tx, rng,
                          params['batch_stats'])
    return state

In [None]:
rng = jax.random.PRNGKey(42)
rngs = np.broadcast_to(rng, (NUM_TPUS,) + rng.shape)
some_dummy_image_batch = sharded_training_images[0]
state = jax.pmap(create_train_state, axis_name='tpu')(rngs,some_dummy_image_batch)
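
Broadcasting one key to every device means each replica starts from identical parameters, since JAX random functions are deterministic in their key. The following sketch, with illustrative variable names, shows that behavior and how `jax.random.split` produces fresh keys, as done for dropout inside `train_batch`.

```python
import jax

key = jax.random.PRNGKey(42)

# Split off a subkey for one use and keep `key` for later splits
key, subkey = jax.random.split(key)
a = jax.random.normal(subkey, (3,))
b = jax.random.normal(subkey, (3,))    # same key, so exactly the same numbers as `a`

# A new split gives a different subkey and therefore different numbers
key, subkey2 = jax.random.split(key)
c = jax.random.normal(subkey2, (3,))
```

Reusing a key is therefore appropriate for replicated initialization, while a new subkey is needed wherever different random draws are wanted.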

## Training 
Next, we train the VGG19 network for 75 epochs and plot the training accuracy to see how well the model does.

In [None]:
start = time.time()
final_state, training_accuracies = train(state, num_epochs=75)
print("Total time: ", time.time() - start, "seconds")

In [None]:
# Plot the training accuracy per epoch
plt.plot(training_accuracies, label="Training accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()  # uses the label given to plt.plot above
plt.show()

### **Conclusion**
In this notebook, we've illustrated how [JAX](https://github.com/google/jax) and [Flax](https://flax.readthedocs.io/en/latest/) can be used to train a neural network from scratch on an image classification dataset, reaching a training accuracy of more than 80%. To see more examples of how to use JAX and Flax with different data formats, please see this discussion post.

Now it's your turn to create some amazing notebooks using [JAX](https://github.com/google/jax) and [Flax](https://flax.readthedocs.io/en/latest/).

### **Useful resources which helped me:**

* https://flax.readthedocs.io/en/latest/notebooks/annotated_mnist.html
* https://jax.readthedocs.io/en/latest/
* https://gist.github.com/fedelebron/b7be87a4feb88786cc142ef99931ff06#file-dog-classifier-ipynb
* https://www.kaggle.com/anujverma126/dog-breed-classification-beginner-s-tutorial 
* https://github.com/google/flax/tree/main/examples/imagenet