### Intro to JAX
[JAX](https://github.com/google/jax) is a framework for high-performance numerical computing and machine learning research, developed by [Google Research](https://research.google/) teams. It lets you write Python programs against a NumPy-compatible API and then transform them with automatic differentiation, vectorization, parallelization, and just-in-time (JIT) compilation for GPUs and TPUs. JAX was designed with performance and speed as a first priority, and is natively compatible with common machine learning accelerators such as [GPUs](https://www.kaggle.com/docs/efficient-gpu-usage) and [TPUs](https://www.kaggle.com/docs/tpu). Large ML models can take ages to train, so you might be interested in using JAX for applications where speed and performance are particularly important!
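As a small illustration of these transformations (a toy sketch, not part of the bird-song pipeline; the loss function and data are made up), the snippet below differentiates a squared-error loss with `jax.grad`, vectorizes that gradient over a batch of examples with `jax.vmap`, and compiles the result with `jax.jit`.

```python
import jax
import jax.numpy as jnp

def squared_error(w, x, y):
    # Loss for a single example: (x . w - y)^2
    return (jnp.dot(x, w) - y) ** 2

# d(loss)/dw for one example
grad_fn = jax.grad(squared_error)
# Map over the batch axis of x and y while sharing the same weights w
per_example_grads = jax.vmap(grad_fn, in_axes=(None, 0, 0))
# Compile the vectorized gradient function with XLA
fast_per_example_grads = jax.jit(per_example_grads)

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0],
                [0.0, 1.0]])
ys = jnp.array([0.0, 0.0])

grads = fast_per_example_grads(w, xs, ys)
print(grads)  # one gradient row per example: 2 * (x . w - y) * x
```

Each transformation returns an ordinary Python function, so they compose freely; the same pattern (`value_and_grad`, `pmap`, `jit`) is used for training later in this notebook.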
### When to use JAX vs TensorFlow?
[TensorFlow](https://www.tensorflow.org/guide) is a fantastic product, with a rich and fully featured ecosystem capable of supporting almost every use case a machine learning practitioner might have (e.g. [TFLite](https://www.tensorflow.org/lite) for on-device inference, [TFHub](https://tfhub.dev/) for sharing pre-trained models, and many additional specialized applications). This broad mandate both contrasts with and complements JAX's philosophy, which is more narrowly focused on speed and performance. We recommend using JAX when you want to maximize speed and performance but do not need any of the long tail of features and additional functionality that only the [TensorFlow ecosystem](https://www.tensorflow.org/learn) can provide.
### Intro to Flax
Just as [JAX](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html) focuses on speed, other members of the JAX ecosystem are encouraged to specialize as well. For example, [Flax](https://flax.readthedocs.io/en/latest/) focuses on neural networks and [Jraph](https://github.com/deepmind/jraph) focuses on graph networks.

[Flax](https://flax.readthedocs.io/en/latest/) is a JAX-based neural network library that was initially developed by Google Research's Brain Team (in close collaboration with the JAX team) and is now open source. If you want to train machine learning models on GPUs and TPUs at an accelerated speed, or if you have an ML project that might benefit from bringing together [Autograd](https://github.com/hips/autograd)-style automatic differentiation and [XLA](https://www.tensorflow.org/xla) compilation, consider using [Flax](https://flax.readthedocs.io/en/latest/) for your next project! [Flax](https://flax.readthedocs.io/en/latest/) is especially well suited to projects that use large language models, and is a popular choice for cutting-edge [machine learning research](https://arxiv.org/search/?query=JAX&searchtype=all&abstracts=show&order=-announced_date_first&size=50).

### Disclaimer:
**We recommend using [GPUs](https://www.kaggle.com/docs/efficient-gpu-usage) when working with JAX on Kaggle.** These notebooks are compatible with the v3-8 [TPUs](https://www.kaggle.com/docs/tpu) that are provided for free in [Kaggle Notebooks](https://www.kaggle.com/code/new), but JAX was optimized for the newly updated [TPU VM](https://cloud.google.com/blog/products/compute/introducing-cloud-tpu-vms) architecture which is not yet available on Kaggle.


## Imports
Uncomment and run this code cell only when the accelerator is a TPU.

In [None]:
#%%capture
#!conda install -y -c conda-forge jax jaxlib flax optax datasets transformers
#!conda install -y importlib-metadata

In [None]:
# Importing all the libraries necessary for the project
import os
# Suppress TensorFlow C++ log messages (e.g. warnings caused by the CUDA version);
# this must be set before TensorFlow is imported to take effect
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import sys
import time
import math
import random
import functools
import gc
import warnings
from typing import Any

import numpy as np
import pandas as pd
import librosa
import cv2
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
from tqdm import tqdm
tqdm.pandas()

import torch
import torchvision
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader, random_split

import tensorflow as tf
from tensorflow.keras.utils import to_categorical

import jax
import jax.nn
import jax.numpy as jnp
import optax
import flax.linen as nn  # `nn` refers to Flax Linen throughout this notebook (not torch.nn)
from flax.training import train_state

# Ignore warnings
warnings.filterwarnings("ignore")

seed = 1234
np.random.seed(seed)

### TPU detection and configuration


In [None]:
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
except ValueError:
    tpu = None
    
if tpu:
    import requests
    if 'TPU_DRIVER_MODE' not in globals():
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        resp = requests.post(url)
        TPU_DRIVER_MODE = 1
    from jax.config import config
    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    print('Registered TPU:', config.FLAGS.jax_backend_target)
else:
    print('No TPU detected.')

In [None]:
# List all the available devices
jax.devices()

## Load and Pre-process the dataset
For time and memory management, we'll take a random sample of 15 bird species.

In [None]:
# seeding function for reproducibility
def seed_everything(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

In [None]:
ROOT = "/kaggle/input/birdsong-recognition/"
os.listdir(ROOT)

In [None]:
df = pd.read_csv(os.path.join(ROOT, 'train.csv'))[['ebird_code', 'filename', 'duration']]
df['path'] = ROOT+'train_audio/' + df['ebird_code'] + "/" + df['filename']
df.head()

In [None]:
SEED = 42
FRAC = 0.2     # Validation fraction
SR = 44100     # sampling rate
MAXLEN= 60    # seconds
N_MELS = 128

seed_everything(SEED)
device = torch.device('cpu')

#Random sample of 15 birds
classes = set(random.sample(df['ebird_code'].unique().tolist(), 15)) 
print(classes)

In [None]:
df = df[df.ebird_code.apply(lambda x: x in classes)].reset_index(drop=True)
keys = set(df.ebird_code)
values = np.arange(0, len(keys))
code_dict = dict(zip(sorted(keys), values))
df['label'] = df['ebird_code'].apply(lambda x: code_dict[x])
df.head()

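Next, we create a custom dataset class using PyTorch's [Datasets and DataLoaders](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#loading-a-dataset). Each item is a normalized log-mel spectrogram of the first 5 seconds of a recording (after trimming silence), together with its integer label.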
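Every clip must end up with the same shape so that it can be batched. The NumPy sketch below (`pad_or_trim` and `num_frames` are hypothetical helper names, not librosa functions) shows the two steps the dataset class relies on: trimming or symmetrically zero-padding the waveform to exactly `SR * duration` samples, and the resulting number of spectrogram frames. It assumes librosa's default centered framing with an even `n_fft`, where a signal of `n` samples yields `1 + n // hop_length` frames.

```python
import numpy as np

SR = 44100         # sampling rate used in this notebook
HOP_LENGTH = 347   # hop length passed to librosa.feature.melspectrogram

def pad_or_trim(audio, samples):
    """Keep the first `samples` values, or zero-pad both sides up to `samples`."""
    if len(audio) >= samples:
        return audio[:samples]
    padding = samples - len(audio)
    offset = padding // 2
    return np.pad(audio, (offset, padding - offset), mode="constant")

def num_frames(num_samples, hop_length=HOP_LENGTH):
    """Frames produced by a centered STFT (assumes librosa's default center=True)."""
    return 1 + num_samples // hop_length

samples = SR * 5  # 5-second clips -> 220500 samples
short_clip = np.ones(1000, dtype=np.float32)
long_clip = np.ones(300000, dtype=np.float32)

print(pad_or_trim(short_clip, samples).shape)  # (220500,)
print(pad_or_trim(long_clip, samples).shape)   # (220500,)
print(num_frames(samples))                     # 636
```

This is where the `(636, 128, 1)` target in `__getitem__` comes from: 636 time frames by `N_MELS = 128` mel bins, plus a trailing channel axis.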

In [None]:
class BirdSoundDataset(Dataset):
    """Bird Sound dataset."""

    HOP_LENGTH = 347  # STFT hop length (in samples) used for the mel spectrogram

    def __init__(self, df, transform=None):
        """
        Args:
            df (pd.DataFrame): must have ['path', 'label'] columns
            transform (callable, optional): stored but not applied in this notebook
        """
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def loadMP3(self, path, duration):
        """
        Args:
            path: path of the audio file
            duration: clip length in seconds; audio is trimmed or zero-padded to this length
        Returns:
            mels: normalized log-mel spectrogram of shape (frames, N_MELS)
        """
        samples = SR * duration
        try:
            audio, _ = librosa.load(path, sr=SR)

            if len(audio) > 0:
                # Trim leading and trailing silence
                audio, _ = librosa.effects.trim(audio)
            if len(audio) > samples:  # long enough: keep the first `duration` seconds
                audio = audio[:samples]
            else:  # too short: zero-pad symmetrically to exactly `samples`
                padding = samples - len(audio)
                offset = padding // 2
                audio = np.pad(audio, (offset, padding - offset), 'constant')

            mels = librosa.feature.melspectrogram(y=audio, sr=SR, n_mels=N_MELS,
                                                  hop_length=self.HOP_LENGTH,
                                                  n_fft=N_MELS * 20,
                                                  fmin=20, fmax=SR // 2)
            mels = librosa.power_to_db(mels).astype(np.float32)
            mels = mels.transpose()  # (frames, N_MELS)
            # Standardize; divide by a small epsilon if the spectrogram is constant
            eps = 0.001
            std = np.std(mels)
            mels = (mels - np.mean(mels)) / (std if std != 0 else eps)
            return mels

        except Exception as e:
            print("Error encountered while parsing file: ", path, e)
            # All-zero spectrogram with the same (frames, N_MELS) layout as a valid clip
            return np.zeros((1 + samples // self.HOP_LENGTH, N_MELS), dtype=np.float32)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        path = self.df['path'].iloc[idx]
        duration = 5
        # Cache each spectrogram as a .npy file so later epochs skip decoding the MP3
        cache_path = "./" + path.split('/')[-1] + ".npy"
        if os.path.exists(cache_path):
            mels = np.load(cache_path)
        else:
            mels = self.loadMP3(path, duration)
            np.save(cache_path, mels)
        label = self.df['label'].iloc[idx]
        # 5 s at 44.1 kHz with hop 347 gives 636 frames; add a trailing channel axis (NHWC)
        mels = np.resize(mels, (636, 128, 1))
        return mels, label

In [None]:
# Dividing the dataset into train and validation sets
df = df.sample(frac=1).reset_index(drop=True)
train_len = int(len(df) * (1-FRAC))
train_df = df.iloc[:train_len]
valid_df = df.iloc[train_len:]
train_df.shape, valid_df.shape

The Dataset retrieves our dataset’s features and labels one sample at a time. While training a model, we typically want to pass samples in “minibatches”, reshuffle the data at every epoch to reduce model overfitting, and use Python’s multiprocessing to speed up data retrieval.

DataLoader is an iterable that abstracts this complexity for us in an easy API.

Source: https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#loading-a-dataset
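The behaviour of the loaders configured in the next cell (`shuffle=True`, `drop_last=True`) can be sketched in plain NumPy. This is a simplified illustration, not how PyTorch implements it, and `minibatch_indices` is a hypothetical helper: reshuffle the indices every epoch, cut them into minibatches, and discard the final incomplete batch so that every batch has exactly `batch_size` items.

```python
import numpy as np

def minibatch_indices(num_items, batch_size, rng, shuffle=True, drop_last=True):
    """Yield arrays of dataset indices, one per minibatch, for a single epoch."""
    order = rng.permutation(num_items) if shuffle else np.arange(num_items)
    # With drop_last, stop at the last multiple of batch_size
    stop = (num_items // batch_size) * batch_size if drop_last else num_items
    for start in range(0, stop, batch_size):
        yield order[start:start + batch_size]

rng = np.random.default_rng(0)
batches = list(minibatch_indices(num_items=100, batch_size=32, rng=rng))

print(len(batches))               # 3: the last 4 items are dropped
print([len(b) for b in batches])  # [32, 32, 32]
```

`drop_last=True` matters in this notebook because every batch is later split evenly across devices with `np.split`, which requires a constant batch size.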

In [None]:
# prepare data loaders 
#NUM_TPUS = jax.device_count()
BATCH_SIZE = 32

train_loader = torch.utils.data.DataLoader(BirdSoundDataset(train_df),
                                           batch_size=BATCH_SIZE, 
                                           num_workers=0, 
                                           shuffle=True, 
                                           drop_last = True)

valid_loader = torch.utils.data.DataLoader(BirdSoundDataset(valid_df), 
                                           batch_size=BATCH_SIZE, 
                                           num_workers=0, 
                                           shuffle=True, 
                                           drop_last = True)

len(train_loader), len(valid_loader)

In [None]:
(image_batch, label_batch) = next(iter(train_loader))
print(image_batch.shape)
print(label_batch.shape)

## Creating batches/shards of the data

Next, we read batches of data from host (CPU) RAM and copy them into the memory of the accelerators used for computation. Each batch is split evenly along the batch axis, one shard per device, with JAX's [`device_put_sharded`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put_sharded.html#jax.device_put_sharded), which returns an array whose leading axis indexes the device (a `ShardedDeviceArray` in older JAX versions).
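Conceptually, splitting a batch with `np.split(batch, num_devices, axis=0)` and stacking the shards produces an array with a new leading device axis, which is the layout `jax.pmap` maps over. The NumPy sketch below assumes a device count of 8 (a v3-8 TPU exposes 8 cores) and checks that this is the same as reshaping to `(num_devices, batch_size // num_devices, ...)`.

```python
import numpy as np

num_devices = 8  # assumption: e.g. the 8 cores of a v3-8 TPU
batch_size = 32  # must be divisible by num_devices
image_batch = np.random.rand(batch_size, 636, 128, 1).astype(np.float32)

# What copy_dataset_to_devices does before calling jax.device_put_sharded
shards = np.split(image_batch, num_devices, axis=0)
stacked = np.stack(shards)

# Equivalent reshape: add a leading device axis
reshaped = image_batch.reshape((num_devices, batch_size // num_devices) + image_batch.shape[1:])

print(stacked.shape)                      # (8, 4, 636, 128, 1)
print(np.array_equal(stacked, reshaped))  # True
```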

In [None]:
NUM_TPUS = jax.device_count()

def copy_dataset_to_devices(dataset, devices, num_reps=1):
    '''
    Iterates over `dataset` `num_reps` times and copies every batch to the
    accelerators, split evenly along the batch axis (one shard per device).
    BATCH_SIZE must be divisible by the number of devices.
    '''
    num_devices = len(devices)
    sharded_images = []
    sharded_labels = []
    for _ in range(num_reps):
        for image_batch, label_batch in tqdm(dataset, ncols=100):
            # torch.Tensor -> np.ndarray, then one shard per device
            image_batch = image_batch.detach().cpu().numpy()
            image_batches = np.split(image_batch, num_devices, axis=0)
            sharded_images.append(jax.device_put_sharded(image_batches, devices))

            label_batch = label_batch.detach().cpu().numpy()
            label_batches = np.split(label_batch, num_devices, axis=0)
            sharded_labels.append(jax.device_put_sharded(label_batches, devices))

    return sharded_images, sharded_labels

devices = jax.local_devices()
# num_reps=10 stores ten reshuffled passes over the training set on the devices
sharded_training_images, sharded_training_labels = copy_dataset_to_devices(train_loader, devices, num_reps=10)

## Model architecture

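In this notebook we'll use a VGG19 network, defined from scratch with the [Flax Linen package](https://flax.readthedocs.io/en/latest/flax.linen.html). Each 3x3 convolution is followed by batch normalization and a ReLU, and the two hidden fully connected layers also use batch normalization and dropout.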
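Because the convolutions use `'SAME'` padding, only the five 2x2 max-pooling layers (stride 2, with Flax's default `'VALID'` padding) change the spatial size, each halving it and rounding down. The arithmetic below traces a `(636, 128, 1)` spectrogram through the network to the length of the flattened vector fed to the first `Dense(4096)` layer.

```python
def pooled_size(size, num_pools=5, window=2, stride=2):
    """Output length after repeated max pooling with 'VALID' padding."""
    for _ in range(num_pools):
        size = (size - window) // stride + 1
    return size

height, width, channels = 636, 128, 512  # input frames, mel bins, last conv width
h, w = pooled_size(height), pooled_size(width)
flat = h * w * channels
dense_params = flat * 4096 + 4096  # weights + biases of the first Dense layer

print(h, w)          # 19 4
print(flat)          # 38912
print(dense_params)  # 159387648
```

So the first fully connected layer alone holds roughly 159 million parameters, the bulk of the model.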

In [None]:
NUM_CLASSES = 15 
class VGG19(nn.Module):
    @nn.compact
    def __call__(self, x, training):
        x = self._stack(x, 64, training)
        x = self._stack(x, 64, training)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
    
        x = self._stack(x, 128, training)
        x = self._stack(x, 128, training)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

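        # Block 3: VGG19 uses four 3x3 conv layers with 256 filters
        x = self._stack(x, 256, training)
        x = self._stack(x, 256, training)
        x = self._stack(x, 256, training)
        x = self._stack(x, 256, training)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))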

        x = self._stack(x, 512, training)
        x = self._stack(x, 512, training)
        x = self._stack(x, 512, training)
        x = self._stack(x, 512, training)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))    
    
        x = self._stack(x, 512, training)
        x = self._stack(x, 512, training)
        x = self._stack(x, 512, training)
        x = self._stack(x, 512, training)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))  

        x = x.reshape((x.shape[0], -1))

        x = nn.Dense(features=4096)(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        x = nn.Dropout(0.5, deterministic=not training)(x)

        x = nn.Dense(features=4096)(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        x = nn.Dropout(0.5, deterministic=not training)(x)
    
        x = nn.Dense(features=NUM_CLASSES)(x)
        return x
  
    @staticmethod
    def _stack(x, features, training, dropout=None):
        x = nn.Conv(features=features, kernel_size=(3, 3), padding='SAME')(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        return x

## Train function
The `train` function loops over the sharded training images and labels, runs `train_batch` on each sharded batch starting from the given state, and prints the average loss and accuracy after every epoch. It returns the final training state together with the list of per-epoch training accuracies.
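Because `train_batch` is `pmap`ped, each metric it returns is an array with one entry per device (all entries equal after the cross-device reduction), so `average_metrics` averages over both the batches and the device axis. A small NumPy sketch with made-up values, assuming 2 devices:

```python
import numpy as np

def average_metrics(metrics):
    '''Average each key over a list of metric dictionaries.'''
    return {k: np.mean([metric[k] for metric in metrics]) for k in metrics[0]}

# Made-up per-step metrics from a hypothetical 2-device run
batch_metrics = [
    {"loss": np.array([2.0, 2.0]), "accuracy": np.array([0.25, 0.25])},
    {"loss": np.array([1.0, 1.0]), "accuracy": np.array([0.75, 0.75])},
]

epoch_metrics = average_metrics(batch_metrics)
print(epoch_metrics)  # loss 1.5, accuracy 0.5
```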

In [None]:
def average_metrics(metrics):
    '''
    Takes the list of dictionaries of the form k: v, and returns a dictionary
     of the form k: (average of the v).
    '''
    return {k: np.mean([metric[k] for metric in metrics])
        for k in metrics[0]}

def train(initial_network_state, num_epochs):
    '''
    Training the model from the given state, returns the state along with the training accuracies
    '''
    training_accuracies = []
    state = initial_network_state
    for i in range(num_epochs):
        batch_metrics = []
        for (image_batch, label_batch) in tqdm(zip(sharded_training_images,
                                               sharded_training_labels),
                                           total=len(sharded_training_images),
                                           ncols=100):
            state, metrics = train_batch(state, image_batch, label_batch)
            batch_metrics.append(metrics)
        train_metrics = average_metrics(batch_metrics)
        print(f'Epoch {i+1} done.', flush=True)
        print(f'  Loss: {train_metrics["loss"]:.4f}, '
          + f'accuracy: {train_metrics["accuracy"]:.4f}', flush=True)
        training_accuracies.append(train_metrics["accuracy"])
    return state, training_accuracies

### Model Initialization functions

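In Flax, the training state is created and updated explicitly. [`flax.training.train_state.TrainState`](https://flax.readthedocs.io/en/latest/_modules/flax/training/train_state.html#TrainState) bundles the step counter, the model's `apply_fn`, its parameters, the Optax optimizer (`tx`) and the optimizer state. We subclass it as `VGGState` so that it also carries the dropout PRNG key and the BatchNorm statistics.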

In [None]:
class VGGState(train_state.TrainState):
    # Extra fields carried alongside step, params and optimizer state
    rng: Any          # PRNG key used for dropout
    batch_stats: Any  # BatchNorm running statistics

    @classmethod
    def create(cls, apply_fn, params, tx, rng, batch_stats):
        opt_state = tx.init(params)
        return cls(step=0, apply_fn=apply_fn, params=params, tx=tx,
                   opt_state=opt_state, rng=rng, batch_stats=batch_stats)

    @classmethod
    def update_rng(cls, state, rng):
        # `replace` keeps step and opt_state; calling `create` again would
        # re-initialize the optimizer state (e.g. Adam's moments) on every step
        return state.replace(rng=rng)

    @classmethod
    def update_batch_stats(cls, state, batch_stats):
        return state.replace(batch_stats=batch_stats)

## Loss & Metrics calculations
Now we define the functions that calculate the training loss and the accuracy from the predicted logits and the labels.
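To sanity-check these definitions, here is a small worked example with made-up logits. It computes softmax cross-entropy directly from `jax.nn.log_softmax` (the quantity `optax.softmax_cross_entropy` returns for one-hot labels) and uses the fact that all-zero logits over `C` classes give a uniform prediction with loss `log(C)`.

```python
import math

import jax
import jax.numpy as jnp

NUM_CLASSES = 15

def accuracy(logits, labels):
    return jnp.mean(jnp.argmax(logits, -1) == labels)

def cross_entropy(logits, labels):
    one_hot = jax.nn.one_hot(labels, NUM_CLASSES)
    # -sum_c y_c * log softmax(logits)_c, averaged over the batch
    return jnp.mean(-jnp.sum(one_hot * jax.nn.log_softmax(logits), axis=-1))

labels = jnp.array([0, 3])

uniform_logits = jnp.zeros((2, NUM_CLASSES))
print(float(cross_entropy(uniform_logits, labels)), math.log(NUM_CLASSES))  # both about 2.708

confident_logits = jnp.zeros((2, NUM_CLASSES)).at[0, 0].set(10.0).at[1, 5].set(10.0)
print(float(accuracy(confident_logits, labels)))  # 0.5: first example right, second wrong
```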

In [None]:
def accuracy(logits, labels):
    '''
    Calculates the accuracy using the given logits and labels
    '''
    return jnp.mean(jnp.argmax(logits, -1) == labels)

def cross_entropy(logits, labels):
    '''
    Mean softmax cross-entropy between the logits and the integer labels
    '''
    one_hot_labels = jax.nn.one_hot(labels, NUM_CLASSES)
    per_example_loss = optax.softmax_cross_entropy(logits, one_hot_labels)
    return jnp.mean(per_example_loss)

def training_loss(image_batch, label_batch, rng, batch_stats, params):
    '''
    Calculates the training loss 
    '''
    logits, batch_stats = VGG19().apply({'params': params,
                                       'batch_stats': batch_stats},
                                      image_batch, 
                                      training=True,
                                      rngs={'dropout': rng},
                                      mutable=['batch_stats'])
    loss = cross_entropy(logits, label_batch)
    return loss, (logits, batch_stats)

## Training a single batch function
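Next, we define the function that trains on a single batch of data: it takes the current training state and a batch of training data, and returns the updated training state along with the training statistics (loss and accuracy). It is wrapped in `jax.pmap` with `axis_name='tpu'`, so it runs in parallel on every device, each device processing its own shard of the batch, and cross-device quantities are combined with collectives such as `jax.lax.pmean` over that axis.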
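The difference between the two cross-device reductions `jax.lax.psum` and `jax.lax.pmean` can be seen with a toy `pmap` that runs on however many local devices are available (a single CPU device works too). Each device holds one value; `psum` gives every device the sum and `pmean` the average. Averaging gradients with `pmean` is the usual choice for data-parallel training, since when each device's loss is a mean over an equal-sized shard, the averaged gradient equals the gradient of the mean loss over the full batch.

```python
import functools

import jax
import jax.numpy as jnp

n = jax.local_device_count()
values = jnp.arange(n, dtype=jnp.float32)  # one scalar per device: 0, 1, ..., n-1

@functools.partial(jax.pmap, axis_name="devices")
def reduce_across_devices(x):
    # Both collectives return a replicated result: the same value on every device
    return jax.lax.psum(x, axis_name="devices"), jax.lax.pmean(x, axis_name="devices")

total, mean = reduce_across_devices(values)
print(total)  # every entry equals sum(values)
print(mean)   # every entry equals mean(values)
```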

In [None]:
@functools.partial(jax.pmap, axis_name='tpu')
def train_batch(state, image_batch, label_batch):
    '''
    Trains on a single (per-device) batch and returns the updated state
    together with the loss and accuracy averaged across devices
    '''
    rng, subrng = jax.random.split(state.rng)
    batch_loss_fn = functools.partial(training_loss, image_batch, label_batch,
                                      subrng, state.batch_stats)
    (batch_loss, (logits, batch_stats)), grads = \
        jax.value_and_grad(batch_loss_fn, has_aux=True)(state.params)

    # Average gradients across devices so every replica applies the same update
    grads = jax.lax.pmean(grads, axis_name='tpu')
    state = state.apply_gradients(grads=grads)

    # Keep the BatchNorm running statistics in sync across replicas
    batch_stats = jax.lax.pmean(batch_stats['batch_stats'], axis_name='tpu')
    state = state.update_batch_stats(state, batch_stats)
    state = state.update_rng(state, rng)

    batch_accuracy = accuracy(logits=logits, labels=label_batch)
    stats = {'loss': jax.lax.pmean(batch_loss, axis_name='tpu'),
             'accuracy': jax.lax.pmean(batch_accuracy, axis_name='tpu')}

    return state, stats

## Creating train state
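Next, we create the initial training state that is passed to `train`. `create_train_state` is mapped over the devices with `jax.pmap`, and every device receives the same PRNG key, so all replicas start from identical parameters.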
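Why broadcast one key instead of splitting it? Giving every device the same PRNG key makes the pmapped initialization produce identical parameters on each replica, which data-parallel training requires. In the toy sketch below, `jax.random.normal` stands in for parameter initialization; converting the key with `np.asarray` before broadcasting is a small defensive choice of this example.

```python
import jax
import numpy as np

n = jax.local_device_count()
rng = jax.random.PRNGKey(42)
rngs = np.broadcast_to(np.asarray(rng), (n,) + rng.shape)  # same key copied once per device

# Stand-in for create_train_state: draw "parameters" from the key on each device
init_params = jax.pmap(lambda key: jax.random.normal(key, (3,)))(rngs)

print(init_params.shape)                         # (n, 3)
print(np.allclose(init_params, init_params[0]))  # True: every device got the same key
```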

In [None]:
def create_train_state(rng, dummy_image_batch):
    net = VGG19()
    params = net.init({'params': rng, 'dropout': rng}, dummy_image_batch, True)
    tx = optax.adam(learning_rate=0.01)
    state = VGGState.create(net.apply, params['params'], tx, rng,
                          params['batch_stats'])
    return state

In [None]:
rng = jax.random.PRNGKey(42)
rngs = np.broadcast_to(rng, (NUM_TPUS,) + rng.shape)
some_dummy_image_batch = sharded_training_images[0]
state = jax.pmap(create_train_state, axis_name='tpu')(rngs,some_dummy_image_batch)

## Training 
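Next, we train the VGG19 network for 25 epochs and plot the training accuracy to see how well the model does. Because the training set was copied to the devices with `num_reps=10`, each epoch here iterates over ten reshuffled passes of the training data.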

In [None]:
start = time.time()
final_state, training_accuracies = train(state, num_epochs=25)
print("Total time: ", time.time() - start, "seconds")

In [None]:
# Plot the training accuracy per epoch
plt.plot(training_accuracies, label="Training accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.show()

### **Conclusion**
In this notebook, we've illustrated how [JAX](https://github.com/google/jax) and [Flax](https://flax.readthedocs.io/en/latest/) can be used to train a neural network from scratch on an audio classification dataset, reaching a training accuracy of more than 95%. To see more examples of how to use [JAX](https://github.com/google/jax) and [Flax](https://flax.readthedocs.io/en/latest/) with different data formats, please see this discussion post.

Now it's your turn to create some amazing notebooks using [JAX](https://github.com/google/jax) and [Flax](https://flax.readthedocs.io/en/latest/).

### **Useful resources which helped me:**
* https://www.kaggle.com/nilaychauhan/convert-cornell-birdcall-recognition-to-tfrecords
* https://www.kaggle.com/servietsky/fast-import-audio-and-save-spectrograms/notebook
* https://www.kaggle.com/dhananjay3/simple-pytorch-starter/notebook
* https://github.com/google/flax/tree/main/examples/imagenet
* https://flax.readthedocs.io/en/latest/notebooks/annotated_mnist.html
* https://jax.readthedocs.io/en/latest/
* https://gist.github.com/fedelebron/b7be87a4feb88786cc142ef99931ff06#file-dog-classifier-ipynb