# Using Transformers from Huggingface
This notebook is an example of how to use Hugging Face models with ZnNL.

In [1]:
import os

# Hide CUDA devices so that the backend printed below is the CPU. This must be
# set before the numerical libraries are imported.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import znnl as nl

import tensorflow_datasets as tfds

import numpy as np
from flax import linen as nn
import optax
from transformers import ResNetConfig, FlaxResNetForImageClassification

# The original statement was left unfinished (`from flax.training import`),
# which is a SyntaxError; `train_state` is the module used further below.
from flax.training import train_state

import jax
print(jax.default_backend())

cpu


In [31]:
# Build the CIFAR-10 data generator used by every section of this notebook.
# The argument 10 is presumably the number of samples to load — TODO confirm.
data_generator = nl.data.CIFAR10Generator(10)

In [32]:
# Display the generator's class to check which implementation was loaded.
data_generator.__class__

znnl.data.cifar10.CIFAR10Generator

In [36]:
# Identity mapping from class index to label for the ten classes
# (keys and values are plain Python ints 0..9).
id2label = {class_id: class_id for class_id in range(10)}

In [38]:
"""
ZnNL: A Zincwarecode package.

License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html

SPDX-License-Identifier: EPL-2.0

Copyright Contributors to the Zincwarecode Project.

Contact Information
-------------------
email: zincwarecode@gmail.com
github: https://github.com/zincware
web: https://zincwarecode.com/

Citation
--------
If you use this module please cite us with:

Summary
-------
"""
from typing import Callable, Sequence, Union

import jax
import jax.numpy as np
import jax.random
import neural_tangents as nt
import optax
from flax.training.train_state import TrainState

from znnl.optimizers.trace_optimizer import TraceOptimizer
from znnl.utils.prng import PRNGKey


class HFBaseModel:
    """
    Base class for huggingface models.

    The class stores the model's call function and parameters in a Flax
    ``TrainState`` and prepares a batched empirical NTK function. Child classes
    must implement ``_ntk_apply_fn`` and ``apply``.
    """

    def __init__(
        self,
        model: Callable,
        optimizer: Union[Callable, "TraceOptimizer"],
        input_shape: tuple,
        ntk_batch_size: int = 10,
        trace_axes: Union[int, Sequence[int]] = (-1,),
    ):
        """
        Construct a ZnNL wrapper around a huggingface model.

        Parameters
        ----------
        model : Callable
                Model to wrap. Its ``__call__`` is used as the apply function and
                its ``params`` attribute as the initial parameters.
        optimizer : Union[Callable, TraceOptimizer]
                optimizer to use in the training. OpTax is used by default and
                cross-compatibility is not assured.
        input_shape : tuple
                Shape of the NN input.
        ntk_batch_size : int, default 10
                Batch size to use in the NTK computation.
        trace_axes : Union[int, Sequence[int]]
                Tracing over axes of the NTK.
                The default value is trace_axes(-1,), which reduces the NTK to a tensor
                of rank 2.
                For a full NTK set trace_axes=().
        """
        # NOTE: this assignment replaces any `apply_fn` a child class may have set
        # before calling this constructor.
        self.apply_fn = model.__call__
        self.params = model.params

        self.optimizer = optimizer
        self.input_shape = input_shape

        # The parameters are taken from the model above, so no RNG is created here.
        self.rng = None

        # initialize the model state
        self.model_state = self._create_train_state()

        # Prepare NTK calculation
        self.empirical_ntk = nt.batch(
            nt.empirical_ntk_fn(f=self._ntk_apply_fn, trace_axes=trace_axes),
            batch_size=ntk_batch_size,
        )
        self.empirical_ntk_jit = jax.jit(self.empirical_ntk)

    def _create_train_state(self) -> "TrainState":
        """
        Create a training state of the model.

        Returns
        -------
        initial state of model to then be trained.

        Notes
        -----
        TODO: Make the TrainState class passable by the user as it can track custom
              model properties.
        """
        # Set dummy optimizer for case of trace optimizer.
        if isinstance(self.optimizer, TraceOptimizer):
            optimizer = optax.sgd(1.0)
        else:
            optimizer = self.optimizer

        return TrainState.create(
            apply_fn=self.apply_fn, params=self.params, tx=optimizer
        )

    def _ntk_apply_fn(self, params: dict, inputs: np.ndarray):
        """
        Apply function used in the NTK computation.

        Parameters
        ----------
        params: dict
                Contains the model parameters to use for the model computation.
        inputs : np.ndarray
                Feature vector on which to apply the model.

        Returns
        -------
        The apply function used in the NTK computation.

        Raises
        ------
        NotImplementedError
                Always; the child class provides the implementation.
        """
        raise NotImplementedError("Implemented in child class")

    def compute_ntk(
        self,
        x_i: np.ndarray,
        x_j: np.ndarray = None,
        infinite: bool = False,
    ):
        """
        Compute the NTK matrix for the model.

        Parameters
        ----------
        x_i : np.ndarray
                Dataset for which to compute the NTK matrix.
        x_j : np.ndarray (optional)
                Dataset for which to compute the NTK matrix. If None, x_i is used.
        infinite : bool (default = False)
                If true, compute the infinite width limit as well.

        Returns
        -------
        NTK : dict
                The NTK matrix for both the empirical and infinite width computation.
                The "infinite" entry is None when ``infinite`` is False.

        Raises
        ------
        NotImplementedError
                If ``infinite`` is True and the model has no ``kernel_fn``.
        """
        if x_j is None:
            x_j = x_i
        empirical_ntk = self.empirical_ntk_jit(x_i, x_j, self.model_state.params)

        infinite_ntk = None
        if infinite:
            try:
                infinite_ntk = self.kernel_fn(x_i, x_j, "ntk")
            except AttributeError as err:
                # Keep the original cause attached to the raised error.
                raise NotImplementedError(
                    "Infinite NTK not available for this model."
                ) from err

        return {"empirical": empirical_ntk, "infinite": infinite_ntk}

    def apply(self, params: dict, inputs: np.ndarray):
        """
        Apply the model to a feature vector using the given parameters.

        Parameters
        ----------
        params: dict
                Contains the model parameters to use for the model computation.
        inputs : np.ndarray
                Feature vector on which to apply the model.

        Returns
        -------
        Output of the model.

        Raises
        ------
        NotImplementedError
                Always; the child class provides the implementation.
        """
        raise NotImplementedError("Implemented in child class")

    def _apply_fn(self, feature_vector: np.ndarray):
        """
        Apply the model.

        Parameters
        ----------
        feature_vector : np.ndarray
                Feature vector on which to apply operation.

        Returns
        -------
        output of the model.

        Raises
        ------
        NotImplementedError
                Always; the child class provides the implementation.
        """
        # A bare `raise` without an active exception fails with a RuntimeError.
        raise NotImplementedError("Implemented in child class")

    def __call__(self, feature_vector: np.ndarray):
        """
        Call the network.

        Parameters
        ----------
        feature_vector : np.ndarray
                Feature vector on which to apply operation.

        Returns
        -------
        output of the model.
        """
        return self.apply(self.model_state.params, feature_vector)


In [39]:
"""
ZnNL: A Zincwarecode package.

License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html

SPDX-License-Identifier: EPL-2.0

Copyright Contributors to the Zincwarecode Project.

Contact Information
-------------------
email: zincwarecode@gmail.com
github: https://github.com/zincware
web: https://zincwarecode.com/

Citation
--------
If you use this module please cite us with:

Summary
-------
"""
import logging
from typing import Callable, List, Sequence, Union

import jax
import jax.numpy as np
from flax import linen as nn

from znnl.models.jax_model import JaxModel

logger = logging.getLogger(__name__)


class HFModel(HFBaseModel):
    """
    ZnNL wrapper class for a huggingface Flax model.
    """

    def __init__(
        self,
        model: Callable,
        optimizer: Callable,
        input_shape: tuple,
        batch_size: int = 10,
        trace_axes: Union[int, Sequence[int]] = (-1,),
    ):
        """
        Construct a huggingface model wrapper.

        Parameters
        ----------
        model : Callable
                Huggingface model to wrap; passed on to the base class.
        optimizer : Callable
                optimizer to use in the training. OpTax is used by default and
                cross-compatibility is not assured.
        input_shape : tuple
                Shape of the NN input.
        batch_size : int
                Size of batch to use in the NTk calculation.
        trace_axes : Union[int, Sequence[int]]
                Tracing over axes of the NTK.
                The default value is trace_axes(-1,), which reduces the NTK to a tensor
                of rank 2.
                For a full NTK set trace_axes=().
        """
        logger.info(
            "Flax models have occasionally experienced memory allocation issues on "
            "GPU. This is an ongoing bug that we are striving to fix soon."
        )

        # `apply_fn` is not set here: the base constructor assigns it from
        # `model.__call__`, so any value stored beforehand would be overwritten.
        super().__init__(
            model=model,
            optimizer=optimizer,
            input_shape=input_shape,
            trace_axes=trace_axes,
            ntk_batch_size=batch_size,
        )

    def _ntk_apply_fn(self, params, inputs: np.ndarray):
        """
        Return an NTK capable apply function.

        Parameters
        ----------
        params : dict
                Network parameters to use in the calculation.
        inputs : np.ndarray
                Data on which to apply the network

        Returns
        -------
        Acts on the data with the model architecture and parameter set.
        """
        # Use the same call path as `apply`: the stored apply function takes the
        # inputs positionally and the parameters through the `params` keyword.
        return self.apply(params, inputs)

    def apply(self, params: dict, inputs: np.ndarray):
        """Apply the model to a feature vector.

        Parameters
        ----------
        params: dict
                Contains the model parameters to use for the model computation.
        inputs : np.ndarray
                Feature vector on which to apply the model.

        Returns
        -------
        Output of the model, taken from the ``logits`` attribute of the model output.
        """
        return self.model_state.apply_fn(inputs, params=params).logits


In [None]:
# Write HF model as Jax model

"""
ZnNL: A Zincwarecode package.

License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html

SPDX-License-Identifier: EPL-2.0

Copyright Contributors to the Zincwarecode Project.

Contact Information
-------------------
email: zincwarecode@gmail.com
github: https://github.com/zincware
web: https://zincwarecode.com/

Citation
--------
If you use this module please cite us with:

Summary
-------
"""
import logging
from typing import Callable, List, Sequence, Union

import jax
import jax.numpy as np
from flax import linen as nn

from znnl.models.jax_model import JaxModel

logger = logging.getLogger(__name__)



class HFFlaxModel(JaxModel):
    """
    Draft of a huggingface model written as a ZnNL Jax model.

    The model is either a user-supplied Flax module or one built from a stack
    of layers.
    """

    def __init__(
        self,
        optimizer: Callable,
        input_shape: tuple,
        batch_size: int = 10,
        layer_stack: List[nn.Module] = None,
        flax_module: nn.Module = None,
        trace_axes: Union[int, Sequence[int]] = (-1,),
        seed: int = None,
    ):
        """
        Construct a Flax model.

        Parameters
        ----------
        optimizer : Callable
                optimizer to use in the training. OpTax is used by default and
                cross-compatibility is not assured.
        input_shape : tuple
                Shape of the NN input.
        batch_size : int
                Size of batch to use in the NTk calculation.
        layer_stack : List[nn.Module]
                A list of flax modules to be used in the call method.
        flax_module : nn.Module
                Flax module to use instead of building one from scratch here.
                Takes precedence over layer_stack if both are given.
        trace_axes : Union[int, Sequence[int]]
                Tracing over axes of the NTK.
                The default value is trace_axes(-1,), which reduces the NTK to a tensor
                of rank 2.
                For a full NTK set trace_axes=().
        seed : int, default None
                Random seed for the RNG. Uses a random int if not specified.

        Raises
        ------
        TypeError
                If neither layer_stack nor flax_module is provided.
        """
        logger.info(
            "Flax models have occasionally experienced memory allocation issues on "
            "GPU. This is an ongoing bug that we are striving to fix soon."
        )
        if layer_stack is None and flax_module is None:
            raise TypeError("Provide either a Flax nn.Module or a layer stack.")

        if flax_module is not None:
            self.model = flax_module
        else:
            # FundamentalModel is not defined in this notebook; import it where it
            # is needed so the layer-stack path does not fail with a NameError.
            from znnl.models.flax_model import FundamentalModel

            self.model = FundamentalModel(layer_stack)

        self.apply_fn = jax.jit(self.model.apply)

        # Save input parameters, call self.init_model
        super().__init__(
            optimizer=optimizer,
            input_shape=input_shape,
            seed=seed,
            trace_axes=trace_axes,
            ntk_batch_size=batch_size,
        )

    def _ntk_apply_fn(self, params, inputs: np.ndarray):
        """
        Return an NTK capable apply function.

        Parameters
        ----------
        params : dict
                Network parameters to use in the calculation.
        inputs : np.ndarray
                Data on which to apply the network

        Returns
        -------
        Acts on the data with the model architecture and parameter set.
        """
        # With `mutable` set, the module returns a tuple; only its first entry
        # (the model output) is used.
        return self.model.apply({"params": params}, inputs, mutable=["batch_stats"])[0]

    def _init_params(self, kernel_init: Callable = None, bias_init: Callable = None):
        """Initialize a state for the model parameters.

        Parameters
        ----------
        kernel_init : Callable
                Define the kernel initialization.
        bias_init : Callable
                Define the bias initialization.

        Returns
        -------
        Initial state for the model parameters.
        """
        if kernel_init:
            self.model.kernel_init = kernel_init
        if bias_init:
            self.model.bias_init = bias_init

        # Initialise with an all-ones dummy input of the configured input shape.
        params = self.model.init(self.rng(), np.ones(list(self.input_shape)))["params"]

        return params

    def apply(self, params: dict, inputs: np.ndarray):
        """Apply the model to a feature vector.

        Parameters
        ----------
        params: dict
                Contains the model parameters to use for the model computation.
        inputs : np.ndarray
                Feature vector on which to apply the model.

        Returns
        -------
        Output of the model.
        """
        return self.apply_fn({"params": params}, inputs)


In [40]:
# From scratch

# Configuration of a bottleneck ResNet with four stages for three-channel images.
resnet_config = ResNetConfig(
    num_channels=3,
    embedding_size=64,
    hidden_sizes=[256, 512, 1024, 2048],
    depths=[3, 4, 6, 3],
    layer_type="bottleneck",
    hidden_act="relu",
    downsample_in_first_stage=False,
    out_features=None,
    out_indices=None,
    id2label=id2label,
)

# Randomly initialised (not pre-trained) classifier built from the config.
model = FlaxResNetForImageClassification(
    config=resnet_config,
    input_shape=(1, 32, 32, 3),
    seed=0,
    _do_init=True,
)

# Wrap the model so that it can be used with the ZnNL training strategies.
znnl_model = HFModel(
    model,
    optax.adam(learning_rate=0.01),
    input_shape=(1, 32, 32, 3),
)

In [41]:
# Forward pass on the first training image as a smoke test of the wrapper.
# NOTE(review): swapaxes(1, 3) presumably moves the channel axis to position 1
# for the Hugging Face model; it also exchanges the remaining two axes —
# confirm that this is intended.
test_input = np.swapaxes(data_generator.train_ds['inputs'][:1], 1, 3)

znnl_model(test_input)

Array([[ 0.05988872, -0.20056242, -4.595742  , -1.6283834 ,  4.054898  ,
        -0.93490016,  0.42661703,  1.5616785 ,  0.32605764,  1.8218877 ]],      dtype=float32)

In [42]:
from znnl.loss_functions.cross_entropy_loss import CrossEntropyDistance
import optax

def loss_fn(prediction, target):
    """Softmax cross entropy between predicted logits and target labels.

    Parameters
    ----------
    prediction
            Logits produced by the model.
    target
            Labels to compare against.

    Returns
    -------
    The result of ``optax.softmax_cross_entropy``.
    """
    return optax.softmax_cross_entropy(logits=prediction, labels=target)


def vmapped_cross_entropy_loss(inputs, targets):
    """Mean of ``loss_fn`` mapped over the leading axis of both arguments.

    Parameters
    ----------
    inputs
            Model predictions; mapped over axis 0.
    targets
            Target labels; mapped over axis 0.

    Returns
    -------
    Mean of the mapped loss values.
    """
    per_sample_loss = jax.vmap(loss_fn, in_axes=(0, 0))
    return per_sample_loss(inputs, targets).mean()

In [47]:
def cross_entropy_loss(logits, labels):
    """Negative mean of ``labels * logits`` summed over the last axis.

    Parameters
    ----------
    logits
            Model outputs. No softmax or logarithm is applied here.
            NOTE(review): presumably expected to be log-probabilities — confirm.
    labels
            Label array multiplied element-wise with ``logits``.

    Returns
    -------
    Scalar loss value.
    """
    summed_over_last_axis = np.sum(labels * logits, axis=-1)
    return -np.mean(summed_over_last_axis)

# Training strategy for the wrapped Hugging Face model using ZnNL's
# cross-entropy loss and label accuracy. The commented-out lines are
# alternative loss functions that are not active.
trainer = nl.training_strategies.SimpleTraining(
    model=znnl_model, 
    # loss_fn=vmapped_cross_entropy_loss,
    loss_fn=nl.loss_functions.CrossEntropyLoss(),
    # loss_fn=nl.loss_functions.MeanPowerLoss(order=2),
    accuracy_fn=nl.accuracy_functions.LabelAccuracy(),
)

In [48]:
# Train for two epochs; the inputs get the same axis swap as the test input above.
batch_wise_training_metrics = trainer.train_model(
    train_ds={
        "inputs": np.swapaxes(data_generator.train_ds["inputs"], 1, 3),
        "targets": data_generator.train_ds["targets"],
    },
    test_ds={
        "inputs": np.swapaxes(data_generator.test_ds["inputs"], 1, 3),
        "targets": data_generator.test_ds["targets"],
    },
    batch_size=100,
    epochs=2,
)

  0%|                                                                      | 0/2 [00:00<?, ?batch/s]

Epoch: 2: 100%|██████████████████████████████████████| 2/2 [00:06<00:00,  3.45s/batch, accuracy=0.1]


## Standard Code

In [None]:
class ProductionModule(nn.Module):
    """
    Simple CNN module.

    Two convolution / ReLU / max-pool stages followed by a dense layer with
    300 features, a ReLU and a final dense layer with 10 outputs.
    """

    @nn.compact
    def __call__(self, x):
        # Two identical feature-extraction stages.
        for _ in range(2):
            x = nn.Conv(features=128, kernel_size=(3, 3))(x)
            x = nn.max_pool(nn.relu(x), window_shape=(3, 3), strides=(2, 2))

        # Flatten everything except the leading axis.
        flat = x.reshape((x.shape[0], -1))
        hidden = nn.relu(nn.Dense(features=300)(flat))

        return nn.Dense(10)(hidden)

In [None]:
# Reference model built from the plain Flax module defined above.
production_model = nl.models.FlaxModel(
    flax_module=ProductionModule(),
    optimizer=optax.adam(learning_rate=0.01),
    input_shape=(1, 32, 32, 3),
)

# Same loss and accuracy functions as for the Hugging Face model.
training_strategy = nl.training_strategies.SimpleTraining(
    model=production_model,
    loss_fn=nl.loss_functions.CrossEntropyLoss(),
    accuracy_fn=nl.accuracy_functions.LabelAccuracy(),
)

In [None]:
# Train the reference model on the generator data without any axis swap.
batch_wise_training_metrics = training_strategy.train_model(
    train_ds={
        "inputs": data_generator.train_ds["inputs"],
        "targets": data_generator.train_ds["targets"],
    },
    test_ds={
        "inputs": data_generator.test_ds["inputs"],
        "targets": data_generator.test_ds["targets"],
    },
    batch_size=100,
)

## Random Data Selection

In [None]:
class RNDModule(nn.Module):
    """
    Simple CNN module.

    Two convolution / ReLU / max-pool stages followed by a single dense layer
    with 300 features.
    """

    @nn.compact
    def __call__(self, x):
        # Two identical feature-extraction stages.
        for _ in range(2):
            x = nn.Conv(features=128, kernel_size=(3, 3))(x)
            x = nn.max_pool(nn.relu(x), window_shape=(3, 3), strides=(2, 2))

        # Flatten everything except the leading axis.
        flat = x.reshape((x.shape[0], -1))

        return nn.Dense(features=300)(flat)

In [None]:
# Two models with the same architecture and optimizer settings.
target = nl.models.FlaxModel(
    flax_module=RNDModule(),
    optimizer=optax.adam(learning_rate=0.01),
    input_shape=(1, 32, 32, 3),
)
predictor = nl.models.FlaxModel(
    flax_module=RNDModule(),
    optimizer=optax.adam(learning_rate=0.01),
    input_shape=(1, 32, 32, 3),
)

In [None]:
# Agent that selects data from the generator; presumably uniformly at random — TODO confirm.
rng_agent = nl.agents.RandomAgent(data_generator=data_generator)

In [None]:
# Build a dataset with the agent; 300 is presumably the number of selected points — TODO confirm.
ds = rng_agent.build_dataset(300)

In [None]:
# Gather the inputs and targets at the indices stored on the agent
# (`rng_agent.target_indices`) along the leading axis.
train_ds = {
    "inputs": np.take(data_generator.train_ds["inputs"], rng_agent.target_indices, axis=0),
    "targets": np.take(data_generator.train_ds["targets"], rng_agent.target_indices, axis=0)
}

In [None]:
# Train on the selected subset with a mean-power loss of order 2.
# NOTE(review): this trains `production_model`, not `target` or `predictor`,
# and rebinds `training_strategy` — confirm that this is intended.
training_strategy = nl.training_strategies.SimpleTraining(
    model=production_model, 
    loss_fn=nl.loss_functions.MeanPowerLoss(order=2),
    accuracy_fn=nl.accuracy_functions.LabelAccuracy(),
)
training_strategy.train_model(
    train_ds=train_ds, 
    test_ds=data_generator.test_ds,
    epochs=100,
    batch_size=50
)