In [1]:
%%capture
!pip install datasets
!pip install git+https://github.com/huggingface/transformers.git
!pip install flax
!pip install git+https://github.com/deepmind/optax.git

In [2]:
import jax
# List the JAX devices attached to this process (sanity check before training).
jax.local_devices()

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]

# For CV

In [14]:
from transformers import FlaxResNetModel, AutoImageProcessor
from PIL import Image
import requests
from flax.training import train_state
import optax

In [4]:
# NOTE(review): num_classes is not referenced by any other visible cell —
# confirm it is still needed.
num_classes = 10
seed = 0
# Load the pretrained `microsoft/resnet-50` checkpoint as a FlaxResNetModel.
model = FlaxResNetModel.from_pretrained('microsoft/resnet-50')

Downloading (…)lve/main/config.json:   0%|          | 0.00/69.6k [00:00<?, ?B/s]

Downloading flax_model.msgpack:   0%|          | 0.00/102M [00:00<?, ?B/s]

In [10]:
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# Bound the request so the cell cannot hang indefinitely, and fail loudly on an
# HTTP error instead of handing an error page to PIL.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)
# Preprocess the image with the processor that matches the checkpoint and
# return NumPy arrays.
image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
inputs = image_processor(images=image, return_tensors="np")

Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


In [15]:
# Bundle the model's forward function, its parameters and an Adam optimizer
# (learning rate 1e-3) into a Flax TrainState.
# NOTE(review): the shape dump of model.params printed further down shows a
# top-level 'batch_stats' entry, so those running statistics are handed to the
# optimizer as if they were trainable — confirm this is intended.
state = train_state.TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=optax.adam(1e-3),
)

In [16]:
# Show only the shape of every leaf in the parameter tree.
# jax.tree_util.tree_map is the stable spelling; the top-level jax.tree_map
# alias is deprecated and absent from newer JAX releases.
jax.tree_util.tree_map(lambda x: x.shape, model.params)

{'batch_stats': {'embedder': {'embedder': {'normalization': {'mean': (64,),
     'var': (64,)}}},
  'encoder': {'stages': {'0': {'layers': {'0': {'layer': {'0': {'normalization': {'mean': (64,),
          'var': (64,)}},
        '1': {'normalization': {'mean': (64,), 'var': (64,)}},
        '2': {'normalization': {'mean': (256,), 'var': (256,)}}},
       'shortcut': {'normalization': {'mean': (256,), 'var': (256,)}}},
      '1': {'layer': {'0': {'normalization': {'mean': (64,), 'var': (64,)}},
        '1': {'normalization': {'mean': (64,), 'var': (64,)}},
        '2': {'normalization': {'mean': (256,), 'var': (256,)}}}},
      '2': {'layer': {'0': {'normalization': {'mean': (64,), 'var': (64,)}},
        '1': {'normalization': {'mean': (64,), 'var': (64,)}},
        '2': {'normalization': {'mean': (256,), 'var': (256,)}}}}}},
    '1': {'layers': {'0': {'layer': {'0': {'normalization': {'mean': (128,),
          'var': (128,)}},
        '1': {'normalization': {'mean': (128,), 'var': (12

# For NLP

In [None]:
# GLUE task identifiers ("mnli-mm" is loaded from the "mnli" dataset, see actual_task).
GLUE_TASKS = [
    "cola",
    "mnli",
    "mnli-mm",
    "mrpc",
    "qnli",
    "qqp",
    "rte",
    "sst2",
    "stsb",
    "wnli",
]

In [None]:
# Run configuration — the values a reader is most likely to tune.
task = "cola"
model_checkpoint = "bert-base-cased"
per_device_batch_size = 4
# Read by the learning-rate schedule cell (num_train_steps / linear_schedule).
# 3 epochs at 2e-5 are common BERT fine-tuning settings; adjust as needed.
num_train_epochs = 3
learning_rate = 2e-5

In [None]:
from datasets import load_dataset, load_metric

In [None]:
# "mnli-mm" is loaded from the "mnli" dataset/metric; every other task uses its own name.
actual_task = task if task != "mnli-mm" else "mnli"
# "stsb" is the only task flagged as regression.
is_regression = (task == "stsb")

# NOTE(review): this cell's saved output echoes the load_metric line, which looks
# like a deprecation warning — confirm and consider migrating the metric loading.
raw_dataset = load_dataset("glue", actual_task)
metric = load_metric('glue', actual_task)

Downloading builder script:   0%|          | 0.00/28.8k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/28.7k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/27.9k [00:00<?, ?B/s]

Downloading and preparing dataset glue/cola to /root/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad...


Downloading data:   0%|          | 0.00/377k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/8551 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1043 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1063 [00:00<?, ? examples/s]

Dataset glue downloaded and prepared to /root/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

  metric = load_metric('glue', actual_task)


Downloading builder script:   0%|          | 0.00/1.84k [00:00<?, ?B/s]

In [None]:
from transformers import AutoTokenizer

# Load the tokenizer that belongs to the chosen checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

Downloading (…)okenizer_config.json:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

In [None]:
# Maps each GLUE task to the dataset column(s) holding its text, as a pair
# (first_column, second_column); the second entry is None for tasks with a
# single sentence per example.
task_to_keys = {
    "cola": ("sentence", None),
    "mnli": ("premise", "hypothesis"),
    "mnli-mm": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}

In [None]:
# Column names holding the text for the selected task (second one may be None).
sentence1_key, sentence2_key = task_to_keys[task]

def preprocess_function(examples):
    """Tokenize a batch of examples and attach their labels.

    Args:
        examples: mapping of column name to values; must contain the text
            column(s) for the current task and a "label" column.

    Returns:
        The tokenizer output (padded/truncated to 128 tokens) with an extra
        "labels" entry copied from ``examples["label"]``.
    """
    columns = [examples[sentence1_key]]
    if sentence2_key is not None:
        # Sentence-pair task: pass the second column as the paired text.
        columns.append(examples[sentence2_key])

    encoded = tokenizer(*columns, padding="max_length", max_length=128, truncation=True)
    encoded["labels"] = examples["label"]
    return encoded

In [None]:
# Apply preprocess_function in batches and drop the raw columns (as named in the
# train split), leaving only the tokenizer outputs and "labels".
tokenized_dataset = raw_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=raw_dataset["train"].column_names,
)

Map:   0%|          | 0/8551 [00:00<?, ? examples/s]

Map:   0%|          | 0/1043 [00:00<?, ? examples/s]

Map:   0%|          | 0/1063 [00:00<?, ? examples/s]

In [None]:
# Pick out the splits used for training and for evaluation.
train_dataset = tokenized_dataset["train"]
eval_dataset = tokenized_dataset["validation"]

In [None]:
# Sanity-check the number of examples in each split.
print('train_dataset:', len(train_dataset))
print('eval_dataset:', len(eval_dataset))

train_dataset: 8551
eval_dataset: 1043


## Fine tuning the model

In [None]:
from transformers import FlaxAutoModelForSequenceClassification, AutoConfig

# Size of the classification head: 3 for the MNLI variants, 1 for "stsb",
# 2 for every other task.
if task.startswith("mnli"):
    num_labels = 3
elif task == "stsb":
    num_labels = 1
else:
    num_labels = 2
seed = 0

# Load the checkpoint with a sequence-classification head of the chosen size.
config = AutoConfig.from_pretrained(model_checkpoint, num_labels=num_labels)
model = FlaxAutoModelForSequenceClassification.from_pretrained(
    model_checkpoint,
    config=config,
    seed=seed,
)

Downloading flax_model.msgpack:   0%|          | 0.00/433M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-cased were not used when initializing FlaxBertForSequenceClassification: {('cls', 'predictions', 'transform', 'LayerNorm', 'scale'), ('cls', 'predictions', 'transform', 'LayerNorm', 'bias'), ('cls', 'predictions', 'transform', 'dense', 'bias'), ('cls', 'predictions', 'bias'), ('cls', 'predictions', 'transform', 'dense', 'kernel')}
- This IS expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of FlaxBertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased 

In [None]:
import flax
import jax
import optax

from itertools import chain
from tqdm.notebook import tqdm
from typing import Callable

import jax.numpy as jnp

from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from flax.training import train_state

In [None]:
import flax.traverse_util as traverse_util

In [None]:
# Global batch size = per-device batch size x number of local devices.
total_batch_size = per_device_batch_size * jax.local_device_count()
print("The overall batch size (both for training and eval) is", total_batch_size)

The overall batch size (both for training and eval) is 4


### Learning rate scheduler

In [None]:
# Whole batches per epoch (any trailing partial batch is not counted) times the
# number of epochs.
num_train_steps = (len(train_dataset) // total_batch_size) * num_train_epochs

# Linear decay from `learning_rate` down to 0 over all training steps.
learning_rate_function = optax.linear_schedule(
    init_value=learning_rate,
    end_value=0,
    transition_steps=num_train_steps,
)

### TrainState

In [None]:
class TrainState(train_state.TrainState):
    """Flax TrainState extended with two callables.

    Both extra fields are declared with ``pytree_node=False``, i.e. they are
    carried on the state but not treated as pytree leaves.
    """
    # Called by eval_step to turn logits into predictions.
    logits_function: Callable = flax.struct.field(pytree_node=False)
    # Called by train_step to compute the loss from logits and targets.
    loss_function: Callable = flax.struct.field(pytree_node=False)

In [None]:
def decay_mask_fn(params):
    """Build a boolean mask, shaped like ``params``, for weight decay.

    A leaf is marked False when its path ends in "bias" or in
    ("LayerNorm", "scale"); every other leaf is marked True.

    Args:
        params: nested dict of parameters.

    Returns:
        A nested dict with the same structure as ``params`` holding booleans.
    """

    def _is_decayed(path):
        is_bias = path[-1] == "bias"
        is_layer_norm_scale = path[-2:] == ("LayerNorm", "scale")
        return not (is_bias or is_layer_norm_scale)

    mask_by_path = {path: _is_decayed(path) for path in traverse_util.flatten_dict(params)}
    return traverse_util.unflatten_dict(mask_by_path)

In [None]:
def adamw(weight_decay):
    """Create the AdamW optimizer used for fine-tuning.

    Args:
        weight_decay: weight-decay coefficient, applied only to the leaves
            selected by ``decay_mask_fn``.

    Returns:
        An optax AdamW transformation driven by ``learning_rate_function``.
    """
    return optax.adamw(
        learning_rate=learning_rate_function,
        b1=0.9,
        b2=0.999,
        eps=1e-6,
        weight_decay=weight_decay,
        mask=decay_mask_fn,
    )

In [None]:
def loss_function(logits, labels):
    """Compute the mean training loss for a batch.

    Uses mean squared error on ``logits[..., 0]`` when ``is_regression`` is
    set, otherwise mean softmax cross-entropy against one-hot labels with
    ``num_labels`` classes.

    Args:
        logits: model outputs.
        labels: target values (class ids, or real-valued targets for regression).

    Returns:
        The loss averaged over the batch.
    """
    if not is_regression:
        one_hot_labels = onehot(labels, num_classes=num_labels)
        per_example = optax.softmax_cross_entropy(logits, one_hot_labels)
        return jnp.mean(per_example)

    residual = logits[..., 0] - labels
    return jnp.mean(residual ** 2)

def eval_function(logits):
    """Turn logits into predictions.

    Returns ``logits[..., 0]`` when ``is_regression`` is set, otherwise the
    argmax over the last axis.
    """
    if is_regression:
        return logits[..., 0]
    return logits.argmax(-1)

In [None]:
# Training state for the sequence classifier: forward function, parameters,
# AdamW optimizer, plus the prediction and loss callables defined above.
# NOTE(review): weight_decay is 0.0 here, so the decay mask passed to AdamW
# should change nothing in the update — confirm this is intended.
state = TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=adamw(0.0),
    logits_function=eval_function,
    loss_function=loss_function,
)

In [None]:
# Show only the shape of every leaf in the parameter tree.
# jax.tree_util.tree_map is the stable spelling; the top-level jax.tree_map
# alias is deprecated and absent from newer JAX releases.
jax.tree_util.tree_map(lambda x: x.shape, model.params)

{'bert': {'embeddings': {'LayerNorm': {'bias': (768,), 'scale': (768,)},
   'position_embeddings': {'embedding': (512, 768)},
   'token_type_embeddings': {'embedding': (2, 768)},
   'word_embeddings': {'embedding': (28996, 768)}},
  'encoder': {'layer': {'0': {'attention': {'output': {'LayerNorm': {'bias': (768,),
        'scale': (768,)},
       'dense': {'bias': (768,), 'kernel': (768, 768)}},
      'self': {'key': {'bias': (768,), 'kernel': (768, 768)},
       'query': {'bias': (768,), 'kernel': (768, 768)},
       'value': {'bias': (768,), 'kernel': (768, 768)}}},
     'intermediate': {'dense': {'bias': (3072,), 'kernel': (768, 3072)}},
     'output': {'LayerNorm': {'bias': (768,), 'scale': (768,)},
      'dense': {'bias': (768,), 'kernel': (3072, 768)}}},
    '1': {'attention': {'output': {'LayerNorm': {'bias': (768,),
        'scale': (768,)},
       'dense': {'bias': (768,), 'kernel': (768, 768)}},
      'self': {'key': {'bias': (768,), 'kernel': (768, 768)},
       'query': {'b

## Training loop

In [None]:
def train_step(state, batch, dropout_rng):
    """Run one optimization step; written to be wrapped in pmap with axis "batch".

    Args:
        state: TrainState providing ``apply_fn``, ``params``, ``loss_function``
            and ``apply_gradients``.
        batch: dict of model inputs plus a "labels" entry; "labels" is popped
            from the dict before the remaining entries are passed to the model.
        dropout_rng: PRNG key; it is split into a key for this step and a
            key returned for the next one.

    Returns:
        ``(new_state, metrics, new_dropout_rng)`` where ``metrics`` holds the
        "loss" and "learning_rate" averaged over the "batch" axis.
    """
    targets = batch.pop("labels")
    step_rng, next_rng = jax.random.split(dropout_rng)

    def compute_loss(params):
        # The first element of the model output is the logits.
        outputs = state.apply_fn(**batch, params=params, dropout_rng=step_rng, train=True)
        return state.loss_function(outputs[0], targets)

    loss, grads = jax.value_and_grad(compute_loss)(state.params)
    # Average gradients over the "batch" axis before applying the update.
    grads = jax.lax.pmean(grads, "batch")
    updated_state = state.apply_gradients(grads=grads)
    metrics = jax.lax.pmean(
        {"loss": loss, "learning_rate": learning_rate_function(state.step)},
        axis_name="batch",
    )
    return updated_state, metrics, next_rng

In [None]:
# pmap train_step; "batch" is the axis name its pmean calls reduce over, and
# donate_argnums=(0,) donates the buffers of the first argument (the state).
parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))

In [None]:
def eval_step(state, batch):
    """Run the model in inference mode and convert its logits to predictions.

    Args:
        state: object providing ``apply_fn``, ``params`` and ``logits_function``.
        batch: dict of keyword inputs forwarded to ``state.apply_fn``.

    Returns:
        ``state.logits_function`` applied to the first element of the model output.
    """
    outputs = state.apply_fn(**batch, params=state.params, train=False)
    logits = outputs[0]
    return state.logits_function(logits)

In [None]:
# pmap eval_step under the same "batch" axis name used for training.
parallel_eval_step = jax.pmap(eval_step, axis_name="batch")

## Defining the data loaders

In [None]:
def glue_train_data_loader(rng, dataset, batch_size):
    """Yield shuffled, sharded training batches for one pass over ``dataset``.

    Args:
        rng: JAX PRNG key used to permute the example indices.
        dataset: indexable dataset; indexing with an index array returns a
            dict of columns.
        batch_size: number of examples per yielded batch.

    Yields:
        Dicts of ``jnp`` arrays passed through ``shard``. Examples that do not
        fill a complete batch are skipped.
    """
    num_batches = len(dataset) // batch_size
    shuffled = jax.random.permutation(rng, len(dataset))
    # Drop the tail that would form an incomplete batch, then group by batch.
    batch_indices = shuffled[: num_batches * batch_size].reshape((num_batches, batch_size))

    for indices in batch_indices:
        columns = dataset[indices]
        yield shard({name: jnp.array(values) for name, values in columns.items()})

In [None]:
def glue_eval_data_loader(dataset, batch_size):
    """Yield sharded evaluation batches in dataset order.

    Args:
        dataset: sliceable dataset; slicing returns a dict of columns.
        batch_size: number of examples per yielded batch.

    Yields:
        Dicts of ``jnp`` arrays passed through ``shard``. Only complete
        batches are produced, so up to ``batch_size - 1`` trailing examples
        are not evaluated.
    """
    num_batches = len(dataset) // batch_size
    for start in range(0, num_batches * batch_size, batch_size):
        rows = dataset[start : start + batch_size]
        yield shard({name: jnp.array(values) for name, values in rows.items()})

In [None]:
# Replicate the training state with flax.jax_utils.replicate so it can be fed
# to the pmapped step functions.
state = flax.jax_utils.replicate(state)

## Training

In [None]:
# Seed the root PRNG key, then derive one dropout key per local device.
rng = jax.random.PRNGKey(seed)
dropout_rngs = jax.random.split(rng, jax.local_device_count())