In [1]:
import os
from typing import Any
from functools import partial

import jax
from jax import lax
import jax.numpy as jnp
import jax.scipy as jsp

import flax
from flax import core
import flax.linen as nn
from flax.training import common_utils

import optax
import msgpack

import tensorflow as tf
import tensorflow_datasets as tfds

import math
import wandb
import numpy as np
from tqdm import tqdm

from kaggle_datasets import KaggleDatasets
from kaggle_secrets import UserSecretsClient


print("Number of devices:", jax.device_count())

[percpu.cc : 552] RAW: rseq syscall failed with errno 1
  from .autonotebook import tqdm as notebook_tqdm


Number of devices: 8


In [2]:
os.environ['WANDB_API_KEY'] = UserSecretsClient().get_secret("wandb_api_key")
wandb.init(project="transfer-learning-jax", entity="ml-colabs", job_type="train")

config = wandb.config

config.seed = 0

config.classes = [
    'pink primrose',    'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea',         'wild geranium',      'tiger lily',
    'moon orchid',      'bird of paradise',          'monkshood',        'globe thistle',     'snapdragon',         "colt's foot",
    'king protea',      'spear thistle',             'yellow iris',      'globe-flower',      'purple coneflower',  'peruvian lily',
    'balloon flower',   'giant white arum lily',     'fire lily',        'pincushion flower', 'fritillary',         'red ginger',
    'grape hyacinth',   'corn poppy',                'prince of wales feathers',              'stemless gentian',   'artichoke',
    'sweet william',    'carnation',                 'garden phlox',     'love in the mist',  'cosmos',             'alpine sea holly',
    'ruby-lipped cattleya', 'cape flower',           'great masterwort', 'siam tulip',        'lenten rose',        'barberton daisy',
    'daffodil',         'sword lily',                'poinsettia',       'bolero deep blue',  'wallflower',         'marigold',
    'buttercup',        'daisy',                     'common dandelion', 'petunia',           'wild pansy',         'primula',
    'sunflower',        'lilac hibiscus',            'bishop of llandaff',                    'gaura',              'geranium',
    'orange dahlia',    'pink-yellow dahlia',        'cautleya spicata', 'japanese anemone',  'black-eyed susan',   'silverbush',
    'californian poppy',                             'osteospermum',     'spring crocus',     'iris',               'windflower',
    'tree poppy',       'gazania',                   'azalea',           'water lily',        'rose',               'thorn apple',
    'morning glory',    'passion flower',            'lotus',            'toad lily',         'anthurium',          'frangipani',
    'clematis',         'hibiscus',                  'columbine',        'desert-rose',       'tree mallow',        'magnolia',
    'cyclamen ',        'watercress',                'canna lily',       'hippeastrum ',      'bee balm',           'pink quill',
    'foxglove',         'bougainvillea',             'camellia',         'mallow',            'mexican petunia',    'bromelia',
    'blanket flower',   'trumpet creeper',           'blackberry lily',  'common tulip',      'wild rose'
]
config.num_classes = 2**(math.ceil(math.log2(len(config.classes))))

config.image_size = 512
config.data_path = os.path.join(
    "/kaggle/input/tpu-getting-started", f'tfrecords-jpeg-{config.image_size}x{config.image_size}'
)

config.efficientnet_v2_path = "/kaggle/input/efficientnetv2-jax/efficientnetv2-m.msgpack"
config.efficientnet_stem_size = 24
config.efficientnet_arch_configs = [
    [1,  24,  3, 1, 0],
    [4,  48,  5, 2, 0],
    [4,  80,  5, 2, 0],
    [4, 160,  7, 2, 1],
    [6, 176, 14, 1, 1],
    [6, 304, 18, 2, 1],
    [6, 512,  5, 1, 1],
]

config.batch_size = 8
config.num_devices = jax.device_count()
config.learning_rate = 7e-5 * config.num_devices
config.weight_decay = 1e-4
config.epochs = 5

[34m[1mwandb[0m: Currently logged in as: [33mgeekyrakshit[0m ([33mml-colabs[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
TRAIN_FILES = tf.io.gfile.glob(os.path.join(config.data_path, 'train', '*.tfrec'))
VAL_FILES = tf.io.gfile.glob(os.path.join(config.data_path, 'val', '*.tfrec'))
TEST_FILES = tf.io.gfile.glob(os.path.join(config.data_path, 'test', '*.tfrec'))

AUTOTUNE = tf.data.AUTOTUNE

print("Number of Train files:", len(TRAIN_FILES))
print("Number of Validation files:", len(VAL_FILES))
print("Number of Test files:", len(TEST_FILES))

Number of Train files: 16
Number of Validation files: 16
Number of Test files: 16


In [4]:
def decode_image(image_data):
    """Given a `tf.string`, returns a legible `tf.tensor`"""
    image = tf.image.decode_jpeg(image_data, channels=3)
    image = tf.reshape(image, [config.image_size, config.image_size, 3]) 
    return image


def read_labeled_tfrecord(example):
    """
    Parses general type example.
    Returns data sample (tuple of (image : ), (label : ))
    """
    labeled_format = {
        # tf.string means bytestring
        "image": tf.io.FixedLenFeature([], tf.string),
        # shape [] means single element
        "class": tf.io.FixedLenFeature([], tf.int64),
    }
    parsed_example = tf.io.parse_single_example(example, labeled_format)
    image = decode_image(parsed_example['image'])
    label = tf.cast(parsed_example['class'], tf.int32)
    return {'image': image, 'label': label}


def read_unlabeled_tfrecord(example):
    """
    Parses general type example, useful for parsing test tfrecord files.
    Returns data sample (tuple of (image : ), (label : ))
    """
    unlabeled_format = {
        "image": tf.io.FixedLenFeature([], tf.string), 
        "id": tf.io.FixedLenFeature([], tf.string),
    }
    parsed_example = tf.io.parse_single_example(example, unlabeled_format)
    image = decode_image(parsed_example['image'])
    idnum = parsed_example['id']
    return {'image': image, 'id': idnum}


def normalize(sample):
    """Given a parsed tfrecord sample, returns the same sample with a normalized image."""
    sample['image'] = tf.cast(sample['image'], tf.float32) / 128. - 1.
    return sample


def to_jax(sample):
    """Given a parsed tfrecord sample, converts it to JAX-pipeline-compatible format."""
    sample['image'] = jnp.array(sample['image'], dtype=jnp.bfloat16)
    sample['label'] = jnp.array(sample['label'], dtype=jnp.int16)
    # Convert labels to one_hot
    sample['label'] = jax.nn.one_hot(
        sample['label'], config.num_classes, dtype=jnp.int16, axis=-1
    )
    return common_utils.shard(sample)

In [5]:
def create_dataloader(
    tfrecord_files,
    is_labeled: bool,
    is_ordered: bool,
    shuffle_buffer_size: int,
    drop_remainder: bool
):
    dataset = tf.data.TFRecordDataset(tfrecord_files, num_parallel_reads=AUTOTUNE)
    
    options = tf.data.Options()
    if not is_ordered:
        options.experimental_deterministic = False
    dataset = dataset.with_options(options)
    
    
    if is_labeled:
        dataset = dataset.map(read_labeled_tfrecord, num_parallel_calls=AUTOTUNE)
    else:
        dataset = dataset.map(read_unlabeled_tfrecord, num_parallel_calls=AUTOTUNE)
    
    dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.batch(
        config.batch_size * config.num_devices, drop_remainder=drop_remainder
    )
    dataset = dataset.map(normalize)
    dataset = dataset.prefetch(AUTOTUNE)
    
    dataset = tfds.as_numpy(dataset)
    dataset = map(to_jax, dataset)
    
    return dataset

In [6]:
train_loader = create_dataloader(
    TRAIN_FILES,
    is_labeled=True,
    is_ordered=False,
    shuffle_buffer_size=4 * config.batch_size,
    drop_remainder=True
)

val_loader = create_dataloader(
    VAL_FILES,
    is_labeled=True,
    is_ordered=True,
    shuffle_buffer_size=4 * config.batch_size,
    drop_remainder=True
)

## Efficientnet-V2

In [7]:
conv_init = nn.initializers.variance_scaling(
    2., mode='fan_out',
    distribution="truncated_normal",
    dtype=jnp.bfloat16
)
dense_init = nn.initializers.variance_scaling(
    1./3, mode='fan_out',
    distribution="truncated_normal",
    dtype=jnp.bfloat16
)


def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

In [8]:
class SqueezeExcitationLayer(nn.Module):
    input_channels : int
    output_channels : int
    dtype : Any
    reduction : int = 4
        
    @nn.compact 
    def __call__(self, x):
        features = _make_divisible(self.input_channels // self.reduction, 4)
        y = nn.Conv(
            features,
            kernel_size=(1, 1),
            param_dtype=self.dtype,
            dtype=self.dtype
        )(x)
        y = jax.nn.silu(y)
        y = nn.Conv(
            self.output_channels,
            kernel_size=(1, 1),
            param_dtype=self.dtype,
            dtype=self.dtype
        )(y)
        y = jax.nn.sigmoid(y)
        return x * y

In [9]:
class ConvBlock(nn.Module):
    output_channels : int
    kernel_size : int
    dtype : Any 
    stride : int = 1
    groups : int = 1
    dropout_rate : float = 0.
    # Whether to do activation
    use_activation : bool = True
    apply_skip_connection : bool = False
        
    @nn.compact 
    def __call__(self, x):
        if self.apply_skip_connection:
            shortcut = x
        x = nn.Conv(
            self.output_channels,
            kernel_size=(self.kernel_size, self.kernel_size), 
            strides=self.stride,
            feature_group_count=self.groups,
            use_bias=False,
            param_dtype=self.dtype,
            dtype=self.dtype,
            kernel_init=conv_init
        )(x)
        mutable = self.is_mutable_collection('batch_stats')
        if self.dropout_rate != 0.:
            x = nn.Dropout(
                rate=self.dropout_rate, deterministic=not mutable
            )(x)
        x = nn.BatchNorm(
            momentum=.9,
            use_running_average=not mutable,
            param_dtype=self.dtype,
            dtype=self.dtype,
            axis_name='devices'
        )(x)
        if self.use_activation:
            x = jax.nn.silu(x)
        if self.apply_skip_connection:
            return x + shortcut 
        else:
            return x

In [10]:
class DropBlock(nn.Module):
    dropblock_rate : float = .0
    @nn.compact 
    def __call__(self, x):
        mutable = self.is_mutable_collection('batch_stats')
        if mutable:
            pred = jax.random.bernoulli(
                self.make_rng('dropout'), p=self.dropblock_rate
            )
            return jax.lax.cond(pred, lambda x : 0., lambda x: 1., 0) * x
        else:
            return (1 - self.dropblock_rate) * x

In [11]:
class MBConv(nn.Module):
    input_channels : int
    output_channels : int 
    stride : int
    expand_ratio : float 
    use_squeeze_excitation : bool 
    dtype : Any 
    dropout_rate : float = 0. 
    dropblock_rate : float = 0.
        
    def setup(self):
        assert self.stride in [1, 2]

        hidden_dim = round(self.input_channels * self.expand_ratio)
        self.identity = self.stride == 1 and self.input_channels == self.output_channels
        if self.dropblock_rate != 0:
            self.DropBlock = DropBlock(dropblock_rate=self.dropblock_rate)
        if self.use_squeeze_excitation:
            self.conv = nn.Sequential([
                ConvBlock(
                    output_channels=hidden_dim,
                    kernel_size=1,
                    stride=1,
                    dtype=self.dtype,
                    dropout_rate=self.dropout_rate
                ),
                ConvBlock(
                    output_channels=hidden_dim,
                    kernel_size=3,
                    stride=self.stride,
                    groups=hidden_dim,
                    dtype=self.dtype,
                    dropout_rate=self.dropout_rate
                ),
                SqueezeExcitationLayer(
                    input_channels=self.input_channels,
                    output_channels=hidden_dim,
                    dtype=self.dtype
                ),
                nn.Conv(
                    self.output_channels,
                    kernel_size=(1, 1),
                    use_bias=False,
                    param_dtype=self.dtype,
                    dtype=self.dtype,
                    kernel_init=conv_init),
            ])
        else:
            self.conv = nn.Sequential([
                ConvBlock(
                    output_channels=hidden_dim,
                    kernel_size=3,
                    stride=self.stride,
                    dtype=self.dtype,
                    dropout_rate=self.dropout_rate
                ),
                nn.Conv(
                    self.output_channels,
                    kernel_size=(1, 1),
                    use_bias=False,
                    param_dtype=self.dtype,
                    dtype=self.dtype,
                    kernel_init=conv_init
                ),
            ])
        self.bn = nn.BatchNorm(
            momentum=.9,
            param_dtype=self.dtype,
            dtype=self.dtype,
            axis_name='devices'
        )

    # Well this .remat almost is not doing anything
    @nn.remat
    def __call__(self, x):
        mutable = self.is_mutable_collection('batch_stats')
        if self.identity:
            if self.dropblock_rate != 0:
                return x + self.DropBlock(
                    self.bn(self.conv(x), use_running_average=not mutable)
                )
            else:
                return x + self.bn(self.conv(x), use_running_average=not mutable)
        else:
            return self.bn(self.conv(x), use_running_average=not mutable)

In [12]:
class EfficientNetV2(nn.Module):
    dtype : Any
    architecture_configs : list 
    width_mult : float = 1.
    dropout_rate : float = 0.
    stem_size : int = 24
    dropblock_rate : float = .0

    def setup(self):
        # building first layer
        input_channel = self.stem_size
        # This is the conv_stem
        self.conv_stem = ConvBlock(
            output_channels=input_channel,
            kernel_size=3,
            stride=2,
            dtype=self.dtype,
            dropout_rate=self.dropout_rate,
            use_activation=False
        )
        # Main computational part
        total_layers = sum([n for (t, c, n, s, use_squeeze_excitation) in self.architecture_configs[1:]])
        layer_count = 1
        layers = []
        for j, (t, c, n, s, use_squeeze_excitation) in enumerate(self.architecture_configs):
            output_channel = _make_divisible(c * self.width_mult, 8)
            block_layers = []
            for i in range(n):
                # The first one should be simply ConvBlock
                if j == 0:
                    block_layers.append(ConvBlock(
                        output_channels=input_channel,
                        kernel_size=3,
                        stride=1,
                        dtype=self.dtype,
                        dropout_rate=self.dropout_rate,
                        apply_skip_connection=True
                    ))
                else:
                    layer_count += 1
                    # Progressive dropblock of MBConv layers
                    block_layers.append(MBConv(
                        input_channels=input_channel,
                        output_channels=output_channel,
                        stride=s if i == 0 else 1,
                        expand_ratio=t,
                        use_squeeze_excitation=use_squeeze_excitation,
                        dtype=self.dtype,
                        dropout_rate=self.dropout_rate,
                        dropblock_rate=self.dropblock_rate * layer_count / (total_layers + 1)
                    ))
                input_channel = output_channel
            layers.append(nn.Sequential(block_layers))
        self.features = nn.Sequential(layers)
        # building last several layers
        self.output_channel = _make_divisible(1792 * self.width_mult, 8) if self.width_mult > 1.0 else 1280
        self.conv_head = ConvBlock(
            output_channels=self.output_channel,
            kernel_size=1,
            dtype=self.dtype,
            dropout_rate=self.dropout_rate
        )

    def __call__(self, x):
        return self.conv_head(self.features(self.conv_stem(x)))

In [13]:
random_key = jax.random.PRNGKey(config.seed)
random_key, efficientnet_key = jax.random.split(random_key)

backbone_model = EfficientNetV2(
    architecture_configs=config.efficientnet_arch_configs,
    dtype=jnp.bfloat16,
    dropout_rate=0.0,
    dropblock_rate=0.0,
    stem_size=config.efficientnet_stem_size
)

dummy_input = jnp.ones((1, config.image_size, config.image_size, 3), dtype=jnp.bfloat16)
backbone_params = backbone_model.init(
    {'dropout': efficientnet_key, 'params': efficientnet_key}, dummy_input
)

with open(config.efficientnet_v2_path, "rb") as f:
    byte_data = f.read()
backbone_params = flax.serialization.from_bytes(backbone_params, byte_data)

backbone_params = jax.tree_map(lambda x : x.astype(jnp.bfloat16), backbone_params)
backbone_params = flax.core.unfreeze(backbone_params)

In [14]:
class EfficientNetV2ClassificationModel(nn.Module):
    backbone: Any
    num_features : int
    dtype : Any
    
    @nn.compact 
    def __call__(self, x):
        mutable = self.is_mutable_collection('batch_stats')
        x = self.backbone(x)
        x = jax.nn.swish(jnp.mean(x, axis=(1, 2)))
        return nn.Dense(
            self.num_features,
            use_bias=False,
            param_dtype=self.dtype,
            dtype=self.dtype
        )(x)

In [15]:
random_key, classifier_key_1, classifier_key_2 = jax.random.split(random_key, 3)

classification_model = EfficientNetV2ClassificationModel(
    backbone=backbone_model, num_features=config.num_classes, dtype=jnp.bfloat16
)
params = classification_model.init({'params': classifier_key_1, 'dropout': classifier_key_2}, dummy_input)
params = flax.core.unfreeze(params)

params['params']['backbone'] = backbone_params['params']
params['batch_stats']['backbone'] = backbone_params['batch_stats']

params = jax.tree_map(lambda x : x.astype(jnp.bfloat16), params)

In [16]:
scheduler = optax.constant_schedule(config.learning_rate)
optimizer = optax.chain(
    optax.clip(1.0),
    optax.adamw(
        learning_rate=scheduler,
        weight_decay=config.weight_decay
    )
)

train_state = {
    'model': params,
    'op': optimizer.init(params['params'])
}
train_state = flax.jax_utils.replicate(train_state)

## Training

In [17]:
def get_accuracy(logits, labels):
    return (logits.argmax(-1) == labels.argmax(-1)).astype(jnp.float32).mean()


def get_cross_entropy(logits, labels):
    return -(jax.nn.log_softmax(logits * 16) * labels).sum(-1).mean()

In [18]:
def train_step(apply_fn, update_fn, train_state, batch, key):
    
    def compute_loss(params):
        logits, mutable_states = apply_fn(
            params,
            batch['image'],
            mutable='batch_stats',
            rngs={'dropout': key}
        )
        
        loss = get_cross_entropy(logits, batch['label'])
        loss = lax.pmean(loss, 'devices')
        
        params['batch_stats'] = mutable_states['batch_stats']
        params['batch_stats'] = jax.tree_map(
            partial(lax.pmean, axis_name='devices'),
            params['batch_stats']
        )
        
        return loss, (logits, params)
    
    gradient_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (loss, (logits, train_state['model'])), gradients = gradient_fn(train_state['model'])
    gradients = lax.pmean(gradients['params'], 'devices')
    
    updates, train_state['op'] = update_fn(gradients, train_state['op'], train_state['model']['params'])
    train_state['model']['params'] = optax.apply_updates(train_state['model']['params'], updates)
    
    accuracy = get_accuracy(logits, batch['label'])
    accuracy = lax.pmean(accuracy, 'devices')
    
    return {'train_state': train_state, 'loss': loss, 'accuracy': accuracy}


parallelized_train_step = jax.pmap(
    partial(
        train_step,
        apply_fn=classification_model.apply,
        update_fn=optimizer.update
    ),
    axis_name='devices'
)

In [19]:
def val_step(apply_fn, params, batch):
    logits = apply_fn(params, batch['image'], mutable=False)
    
    loss = get_cross_entropy(logits, batch['label'])
    loss = lax.pmean(loss, 'devices')
    
    accuracy = get_accuracy(logits, batch['label'])
    accuracy = lax.pmean(accuracy, 'devices')
    
    return {'loss': loss, 'accuracy': accuracy}


parallelized_val_step = jax.pmap(
    partial(val_step, apply_fn=classification_model.apply),
    axis_name='devices'
)

In [20]:
def define_dataloaders():
    train_loader = create_dataloader(
        TRAIN_FILES,
        is_labeled=True,
        is_ordered=False,
        shuffle_buffer_size=4 * config.batch_size,
        drop_remainder=True
    )
    val_loader = create_dataloader(
        VAL_FILES,
        is_labeled=True,
        is_ordered=True,
        shuffle_buffer_size=4 * config.batch_size,
        drop_remainder=True
    )
    return train_loader, val_loader

In [21]:
for epoch in range(1, config.epochs + 1):
    train_loader, val_loader = define_dataloaders()
    train_state['model'] = flax.core.unfreeze(train_state['model'])
    train_epoch_loss, train_epoch_accuracy, train_counter = 0, 0, 0
    val_epoch_loss, val_epoch_accuracy, val_counter = 0, 0, 0
    
    # Training
    train_progress_bar = tqdm(
        train_loader,
        total=199,
        leave=False,
        desc=f"Train {epoch}/{config.epochs}"
    )
    for batch in train_progress_bar:
        random_key, train_key = jax.random.split(random_key)
        output = parallelized_train_step(
            train_state=train_state,
            batch=batch,
            key=common_utils.shard_prng_key(train_key)
        )
        train_state = output['train_state']
        train_loss = output['loss'][0].item()
        train_accuracy = output['accuracy'][0].item()
        train_progress_bar.set_postfix({
            'epoch': epoch,
            'train-loss': train_loss,
            'train-accuracy': train_accuracy
        })
        train_epoch_loss += train_loss
        train_epoch_accuracy += train_accuracy
        train_counter += 1
    train_state['model'] = flax.core.freeze(train_state['model'])
    train_epoch_loss = train_epoch_loss / train_counter
    train_epoch_accuracy = train_epoch_accuracy / train_counter
    
    # Validation
    val_progress_bar = tqdm(
        val_loader,
        total=58,
        leave=False,
        desc=f"Validation {epoch}/{config.epochs}"
    )
    for batch in val_progress_bar:
        output = parallelized_val_step(params=train_state['model'], batch=batch)
        val_loss = output['loss'][0].item()
        val_accuracy = output['accuracy'][0].item()
        val_progress_bar.set_postfix({
            'epoch': epoch,
            'val-loss': val_loss,
            'val-accuracy': val_accuracy
        })
        val_epoch_loss += val_loss
        val_epoch_accuracy += val_accuracy
        val_counter += 1
    val_epoch_loss = val_epoch_loss / train_counter
    val_epoch_accuracy = val_epoch_accuracy / train_counter
    
    wandb.log({
        "train-loss": train_epoch_loss,
        "val-loss": val_epoch_loss,
        "train-accuracy": train_epoch_accuracy,
        "val-accuracy": val_epoch_accuracy,
    })
    
    # Save Checkpoint
    with open("checkpoint.msgpack", "wb") as outfile:
        state_dict = flax.serialization.to_state_dict(train_state)
        serialized_state_dict = flax.serialization.msgpack_serialize(state_dict)
        outfile.write(serialized_state_dict)
    
    # Upload Checkpoint as Weights & Biases Artifacts
    artifact = wandb.Artifact(f'checkpoint-run-{wandb.run.id}', type='train-state')
    artifact.add_file("checkpoint.msgpack")
    wandb.log_artifact(artifact, aliases=["latest", f"epoch-{epoch}"])

                                                                                                              

In [22]:
wandb.finish()

0,1
train-accuracy,▁▇███
train-loss,█▁▁▁▁
val-accuracy,▁▆▇██
val-loss,█▃▂▁▁

0,1
train-accuracy,1.0
train-loss,0.00176
val-accuracy,0.26665
val-loss,0.13203
