# Using Transformers from Huggingface
This example notebook shows how to use Hugging Face models with ZnNL.

In [None]:
# import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import jax
import numpy as np
import optax
from jax.flatten_util import ravel_pytree
from transformers import ResNetConfig, FlaxResNetForImageClassification

import znnl as nl
from znnl.models import HuggingFaceFlaxModel
# Show which backend JAX uses by default in this session.
print(jax.default_backend())

In [None]:
# Build the CIFAR-10 data generator. The positional argument is passed through
# unchanged (presumably the number of samples to load -- TODO confirm).
data_generator = nl.data.CIFAR10Generator(2)

# Input data needs to have shape (num_points, channels, height, width).
# The generator's images are assumed to be channels-last, i.e.
# (num_points, height, width, channels) -- TODO confirm against the generator.
# `np.moveaxis(x, -1, 1)` relocates only the channel axis and keeps the
# height/width order intact; swapping axes 1 and 3 would also exchange height
# and width and thereby transpose every image.
train_ds = {
    "inputs": np.moveaxis(data_generator.train_ds["inputs"], -1, 1),
    "targets": data_generator.train_ds["targets"],
}
test_ds = {
    "inputs": np.moveaxis(data_generator.test_ds["inputs"], -1, 1),
    "targets": data_generator.test_ds["targets"],
}

# Hand the channels-first data sets back to the generator so that later cells
# can keep using `data_generator.train_ds` / `data_generator.test_ds`.
data_generator.train_ds = train_ds
data_generator.test_ds = test_ds

# Execute

## Two default architectures

```python 
# Two standard ResNet architectures

resnet50_config = ResNetConfig(
    num_channels = 3,
    embedding_size = 64, 
    hidden_sizes = [256, 512, 1024, 2048], 
    depths = [3, 4, 6, 3], 
    layer_type = 'bottleneck', 
    hidden_act = 'relu', 
    downsample_in_first_stage = False, 
    out_features = None, 
    out_indices = None, 
    id2label = {i: i for i in range(10)}, # Dummy labels to define the output dimension
    return_dict = True,
)

# ResNet-18 taken from https://huggingface.co/microsoft/resnet-18/blob/main/config.json
resnet18_config = ResNetConfig(
    num_channels = 3,
    embedding_size = 64, 
    hidden_sizes = [64, 128, 256, 512], 
    depths = [2, 2, 2, 2], 
    layer_type = 'basic', 
    hidden_act = 'relu', 
    downsample_in_first_stage = False, 
    id2label = {i: i for i in range(10)}, # Dummy labels to define the output dimension
    return_dict = True,
)
```

## Some example execution code

In [None]:
# Build a small ResNet from scratch (randomly initialised, no pretrained weights).

# Architecture description: three stages of width 6 with two "basic" layers each.
resnet_config = ResNetConfig(
    num_channels=3,
    embedding_size=12,
    hidden_sizes=[6, 6, 6],
    depths=[2, 2, 2],
    layer_type="basic",
    hidden_act="relu",
    downsample_in_first_stage=False,
    out_features=None,
    out_indices=None,
    # Dummy labels to define the output dimension (10 classes).
    id2label={label: label for label in range(10)},
    return_dict=True,
)

# Instantiate the Hugging Face Flax model; `_do_init=True` initialises the
# parameters right away using the given seed and example input shape.
_model = FlaxResNetForImageClassification(
    config=resnet_config,
    input_shape=(1, 32, 32, 3),
    seed=0,
    _do_init=True,
)

# Wrap the Hugging Face model so that it can be used with the ZnNL tooling.
model = HuggingFaceFlaxModel(
    _model,
    optax.adam(learning_rate=1e-3),
    store_on_device=False,
    batch_size=2,
)

In [None]:
# Recorder with loss, accuracy, NTK, covariance entropy, magnitude variance,
# trace and loss derivative enabled, evaluated on the training data set.
train_recorder = nl.training_recording.JaxRecorder(
    name="train_recorder",
    update_rate=1,
    chunk_size=1000,
    loss=True,
    accuracy=True,
    loss_derivative=True,
    ntk=True,
    trace=True,
    covariance_entropy=True,
    magnitude_variance=True,
)
train_recorder.instantiate_recorder(data_set=data_generator.train_ds)

# Training strategy that optimises `model` and feeds the recorder above.
trainer = nl.training_strategies.SimpleTraining(
    model=model,
    recorders=[train_recorder],
    loss_fn=nl.loss_functions.CrossEntropyLoss(),
    accuracy_fn=nl.accuracy_functions.LabelAccuracy(),
)

In [None]:
# Train for 10 epochs with batches of 2 samples and keep the metrics returned
# by the trainer for plotting further down.
batch_wise_training_metrics = trainer.train_model(
    epochs=10,
    batch_size=2,
    train_ds=data_generator.train_ds,
    test_ds=data_generator.test_ds,
)

### Check the results

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Collect everything the recorder stored during training.
train_report = train_recorder.gather_recording()

# Number of model parameters: flatten the parameter pytree into one 1-D vector
# and take its length. `.size` yields a plain integer, whereas `.shape` would
# yield a one-element tuple. `ravel_pytree` is imported explicitly because
# `jax.flatten_util` is not guaranteed to be reachable as an attribute after a
# bare `import jax`.
num_params = ravel_pytree(model.model_state.params)[0].size

In [None]:
# Draw the training loss, the covariance entropy and the trace divided by the
# number of parameters on one shared logarithmic y-axis.
fig, ax = plt.subplots()
ax.plot(batch_wise_training_metrics["train_losses"], label="train loss")
ax.plot(train_report.covariance_entropy, label="covariance_entropy")
ax.plot(train_report.trace / num_params, label="trace")
ax.set_yscale("log")
ax.legend()
plt.show()

# Train Flag 

The kwarg `train=True` is passed to the forward pass while the model is being fitted to data. 
How this differs from setting the kwarg to `False` can be found in:
https://flax.readthedocs.io/en/latest/_modules/flax/linen/normalization.html#BatchNorm

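I don't fully understand the difference made in a forward pass depending on whether the model is being trained or not (the linked `BatchNorm` layer switches between the statistics of the current batch and stored running averages).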
However, the difference is clearly visible when training a model in the following.

In [None]:
from jax import random
from znnl.training_strategies import SimpleTraining
from znnl.loss_functions import CrossEntropyLoss
from transformers import FlaxResNetForImageClassification, ResNetConfig
from znnl.models import HuggingFaceFlaxModel
import optax
import jax.numpy as np
from znnl.training_recording import JaxRecorder

In [None]:
# Small ResNet for the train-flag study: 2 input channels, 3 output classes.
resnet_config = ResNetConfig(
    num_channels=2,
    embedding_size=24,
    hidden_sizes=[12, 12, 12],
    depths=[2, 2, 2],
    layer_type="basic",
    hidden_act="relu",
    downsample_in_first_stage=False,
    out_features=None,
    out_indices=None,
    id2label={i: i for i in range(3)},  # Dummy labels to define the output dimension
    return_dict=True,
)

# Note that `input_shape` is given channels-last, (1, 8, 8, 2), while the data
# created below is channels-first, (30, 2, 8, 8) -- the same convention as in
# the first example of this notebook.
_resnet = FlaxResNetForImageClassification(
    config=resnet_config,
    input_shape=(1, 8, 8, 2),
    seed=0,
    _do_init=True,
)

# Wrap the Hugging Face model for use with ZnNL, optimised with plain SGD.
resnet = HuggingFaceFlaxModel(
    _resnet,
    optax.sgd(learning_rate=1e-4),
    batch_size=3,
)


# Synthetic data set: 30 Gaussian-noise inputs and one-hot targets for
# 3 classes, 10 consecutive samples per class.
key = random.PRNGKey(0)
train_ds = {
    "inputs": random.normal(key, (30, 2, 8, 8)),
    "targets": np.repeat(np.eye(3), 10, axis=0),
}

# Recorder that only tracks the loss. `chunk_size` is passed as an integer,
# consistent with the recorder defined earlier in this notebook (the float
# literal `1e5` has the same value but the wrong type).
train_recorder = JaxRecorder(
    name="train_recorder",
    loss=True,
    update_rate=1,
    chunk_size=100000,
)
train_recorder.instantiate_recorder(
    data_set=train_ds
)


# Training strategy for the study; no accuracy function is supplied here.
trainer = SimpleTraining(
    model=resnet,
    loss_fn=CrossEntropyLoss(),
    recorders=[train_recorder],
)

In [None]:
# To run the study just uncomment the following block.
# NOTE: the plotting cell further down reads `batched_loss`, which is only
# defined once this block has been executed.

# batched_loss = trainer.train_model(
#     train_ds=train_ds,
#     test_ds=train_ds,
#     epochs=200,
#     batch_size=30,
# )

The following code example shows the difference in the forward pass using `train=True` and `train=False`.

The function that is used to evaluate the model in the recorder is defined with `train=False`. 
The forward pass evaluated in the training itself is defined with `train=True`. 

In [None]:
import matplotlib.pyplot as plt

# Loss values collected by the recorder (evaluated with `train=False`).
train_report = train_recorder.gather_recording()

plt.plot(train_report.loss, label="loss using train=False")
# `batched_loss` only exists if the optional training cell above has been
# uncommented and executed. Guard the access so that a fresh
# "Restart & Run All" does not stop here with a NameError.
if "batched_loss" in globals():
    plt.plot(batched_loss["train_losses"], label="loss using train=True")
plt.yscale("log")
plt.title("Train Losses")
plt.legend()
plt.show()