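# Using Transformers from Hugging Face
This example notebook shows how to use Hugging Face models with ZnNL. It builds a small Flax ResNet from scratch, trains it on CIFAR-10, and records training metrics (loss, NTK-related observables) along the way.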

In [None]:
import os

# Set FORCE_CPU = True to hide all GPUs from JAX. This has to run before JAX is
# imported (directly or via znnl/transformers) and initialises its backend;
# if JAX was already imported in this kernel, restart the kernel first.
FORCE_CPU = False
if FORCE_CPU:
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import jax
import matplotlib.pyplot as plt
import numpy as np
import optax
from transformers import FlaxResNetForImageClassification, ResNetConfig

import znnl as nl
from znnl.models import HuggingFaceFlaxModel

# Report which backend JAX is using (e.g. "cpu" or "gpu").
print(jax.default_backend())

In [None]:
# Load CIFAR-10 through ZnNL. The argument is presumably the number of samples
# per split (kept tiny for a quick demo) — TODO confirm against CIFAR10Generator.
data_generator = nl.data.CIFAR10Generator(2)


def to_channels_first(dataset: dict) -> dict:
    """Return a new dataset dict whose images are moved from NHWC to NCHW.

    The model input needs shape (num_points, channels, height, width).
    ``np.moveaxis`` moves only the channel axis, so height and width keep their
    order; ``np.swapaxes(x, 1, 3)`` would instead yield (N, C, W, H) and
    transpose every image spatially.

    Assumes ``dataset["inputs"]`` is channels-last (N, H, W, C) — TODO confirm
    against CIFAR10Generator. Only the "inputs" and "targets" keys are kept.
    """
    return {
        "inputs": np.moveaxis(dataset["inputs"], -1, 1),
        "targets": dataset["targets"],
    }


train_ds = to_channels_first(data_generator.train_ds)
test_ds = to_channels_first(data_generator.test_ds)

# Sanity check: channel axis now sits at position 1 (RGB -> 3 channels).
assert train_ds["inputs"].shape[1] == 3, "expected channels-first RGB images"
assert test_ds["inputs"].shape[1] == 3, "expected channels-first RGB images"

# Downstream cells read the data through the generator, so store it back there.
data_generator.train_ds = train_ds
data_generator.test_ds = test_ds

## Build and train a ResNet from scratch

In [None]:
# Build a small ResNet from scratch (random initialisation, no pretrained weights).
NUM_CLASSES = 10  # CIFAR-10 has ten classes

resnet_config = ResNetConfig(
    num_channels=3,
    embedding_size=24,
    hidden_sizes=[12, 12, 12],  # one width per stage
    depths=[3, 4, 6],           # number of layers per stage
    layer_type="bottleneck",
    hidden_act="relu",
    downsample_in_first_stage=False,
    out_features=None,
    out_indices=None,
    # Plain Python ints/strings (Dict[int, str], as HF documents it). numpy
    # integers here are not JSON-serialisable, which breaks displaying the
    # config and save_pretrained.
    id2label={i: str(i) for i in range(NUM_CLASSES)},
    return_dict=True,
)

model = FlaxResNetForImageClassification(
    config=resnet_config,
    # Dummy input shape used to initialise the weights: one 32x32 RGB image,
    # channels last. NOTE(review): the data is fed channels-first; HF's Flax
    # ResNet is understood to transpose NCHW -> NHWC at call time — confirm.
    input_shape=(1, 32, 32, 3),
    seed=0,
    _do_init=True,
)

# Wrap the Flax model with an optimizer so ZnNL's trainer and recorders can use it.
znnl_model = HuggingFaceFlaxModel(
    model,
    optax.adamw(learning_rate=0.001),
)

In [None]:
# Recorder for observables on the training set; each boolean flag switches on
# one recorded quantity. update_rate=1 presumably records at every update
# (epoch) — confirm against JaxRecorder.
train_recorder = nl.training_recording.JaxRecorder(
    name="train_recorder",
    update_rate=1,
    loss=True,
    loss_derivative=True,
    ntk=True,
    trace=True,
    covariance_entropy=True,
    magnitude_variance=True,
)
train_recorder.instantiate_recorder(data_set=data_generator.train_ds)

# Plain training loop: cross-entropy loss, label accuracy, and the recorder above.
trainer = nl.training_strategies.SimpleTraining(
    model=znnl_model,
    recorders=[train_recorder],
    loss_fn=nl.loss_functions.CrossEntropyLoss(),
    accuracy_fn=nl.accuracy_functions.LabelAccuracy(),
)

In [None]:
# Training hyperparameters.
# NOTE(review): CIFAR10Generator(2) presumably yields very small splits, so a
# batch of 100 may exceed the dataset size — confirm this is intended.
BATCH_SIZE = 100
EPOCHS = 50

# Run training; keep whatever train_model returns for later inspection.
batch_wise_training_metrics = trainer.train_model(
    train_ds=data_generator.train_ds,
    test_ds=data_generator.test_ds,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
)

In [None]:
# Collect everything train_recorder stored during training into one report
# object; its fields (loss, covariance_entropy, trace, ...) are plotted below.
train_report = train_recorder.gather_recording()

In [None]:
import matplotlib.pyplot as plt

In [None]:
# The trace is divided by a constant so it fits on the same axis as the other
# curves. The legend states the scaling so the plotted values are not misread.
TRACE_SCALE = 5000

fig, ax = plt.subplots()
ax.plot(train_report.loss, label="loss")
ax.plot(train_report.covariance_entropy, label="covariance_entropy")
ax.plot(train_report.trace / TRACE_SCALE, label=f"trace / {TRACE_SCALE}")
ax.set_yscale("log")  # non-positive values are not drawn on a log axis
ax.set_xlabel("recorded step")
ax.set_ylabel("value (log scale)")
ax.set_title("Training observables")
ax.legend()
plt.show()