# CIFAR10 Example

In [None]:
import os
# Set before the imports below so the variable is already in place when they
# run. NOTE(review): "-1" presumably hides all CUDA devices so the notebook
# runs on CPU -- confirm this is the intended configuration.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import znnl as nl

import tensorflow_datasets as tfds

import numpy as np
from flax import linen as nn
import optax
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px

import jax
# Print the name of the backend JAX reports as its default.
print(jax.default_backend())

In [None]:
# Create the CIFAR10 data generator; later cells read its `train_ds` and
# `test_ds` dictionaries (keys "inputs" and "targets").
data_generator = nl.data.CIFAR10Generator()

In [None]:
# Call the generator's plotting helper for indices 0 through 4.
data_generator.plot_image(indices=[0, 1, 2, 3, 4])

In [None]:
class ProductionModule(nn.Module):
    """
    Small convolutional network with a 10-unit output layer.

    Two identical stages (3x3 convolution with 128 features, ReLU, 3x3 max
    pooling with stride 2) are followed by a flatten, a 300-unit dense layer
    with ReLU, and a final 10-unit dense layer without activation.
    """

    @nn.compact
    def __call__(self, x):
        # Two identical convolution / activation / pooling stages.
        for _ in range(2):
            x = nn.Conv(features=128, kernel_size=(3, 3))(x)
            x = nn.relu(x)
            x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))

        # Keep the leading axis and collapse every remaining axis into one.
        leading = x.shape[0]
        x = x.reshape((leading, -1))

        hidden = nn.relu(nn.Dense(features=300)(x))
        return nn.Dense(10)(hidden)

In [None]:
# Wrap the CNN defined above in a znnl model, using an Adam optimizer and a
# single 32x32x3 sample as the input shape.
production_model = nl.models.FlaxModel(
    flax_module=ProductionModule(),
    optimizer=optax.adam(learning_rate=0.01),
    input_shape=(1, 32, 32, 3),
)

# Training strategy for the model with a cross-entropy loss and a label
# accuracy metric.
training_strategy = nl.training_strategies.SimpleTraining(
    model=production_model,
    loss_fn=nl.loss_functions.CrossEntropyLoss(),
    accuracy_fn=nl.accuracy_functions.LabelAccuracy(),
)

In [None]:
# Train on the generator's full training split, passing only the "inputs" and
# "targets" entries of each data set, and keep the returned metrics.
batch_wise_training_metrics = training_strategy.train_model(
    train_ds={
        "inputs": data_generator.train_ds["inputs"],
        "targets": data_generator.train_ds["targets"],
    },
    test_ds={
        "inputs": data_generator.test_ds["inputs"],
        "targets": data_generator.test_ds["targets"],
    },
    batch_size=100,
)

## Random Data Selection

In [None]:
class RNDModule(nn.Module):
    """
    Small convolutional network ending in a 300-unit dense layer.

    Two identical stages (3x3 convolution with 128 features, ReLU, 3x3 max
    pooling with stride 2) are followed by a flatten and a single 300-unit
    dense layer with no activation applied afterwards.
    """

    @nn.compact
    def __call__(self, x):
        # Two identical convolution / activation / pooling stages.
        for _ in range(2):
            x = nn.Conv(features=128, kernel_size=(3, 3))(x)
            x = nn.relu(x)
            x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))

        # Keep the leading axis and collapse every remaining axis into one.
        leading = x.shape[0]
        x = x.reshape((leading, -1))

        return nn.Dense(features=300)(x)

In [None]:
# Two separately constructed models sharing the same architecture, optimizer
# settings and input shape.
target = nl.models.FlaxModel(
    flax_module=RNDModule(),
    optimizer=optax.adam(learning_rate=0.01),
    input_shape=(1, 32, 32, 3),
)

predictor = nl.models.FlaxModel(
    flax_module=RNDModule(),
    optimizer=optax.adam(learning_rate=0.01),
    input_shape=(1, 32, 32, 3),
)

In [None]:
# Build a RandomAgent on top of the CIFAR10 data generator.
rng_agent = nl.agents.RandomAgent(data_generator=data_generator)

In [None]:
# Let the agent build a data set. NOTE(review): 300 is presumably the number
# of points to select -- confirm against the agent's API. The following cell
# reads the chosen indices from `rng_agent.target_indices`.
ds = rng_agent.build_dataset(300)

In [None]:
# Gather the rows picked by the agent (along axis 0) from both the inputs and
# the targets of the full training split.
train_ds = {
    key: np.take(data_generator.train_ds[key], rng_agent.target_indices, axis=0)
    for key in ("inputs", "targets")
}

In [None]:
# Train on the subset gathered into `train_ds`, passing the generator's test
# split as test data.
# NOTE(review): this reuses `production_model`, which an earlier cell already
# trained on the full training split -- confirm that continuing from those
# weights (rather than starting from a freshly initialised model) is intended.
training_strategy = nl.training_strategies.SimpleTraining(
    model=production_model,
    loss_fn=nl.loss_functions.MeanPowerLoss(order=2),
    accuracy_fn=nl.accuracy_functions.LabelAccuracy(),
)

# Kept as the final bare expression so the notebook displays its return value.
training_strategy.train_model(
    train_ds=train_ds,
    test_ds=data_generator.test_ds,
    epochs=100,
    batch_size=50,
)