## Installs, imports, etc.

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_data_validation
import numpy as np

import pandas as pd
import seaborn as sns
import sklearn
from sklearn.utils.class_weight import compute_class_weight

## Load the dataset

In [None]:
# Request the single 'train' split directly (a plain split name returns one
# dataset rather than a list). `as_supervised=True` yields (input, label)
# pairs; the label is indexed by output name below, e.g. "varroa_output".
bee_ds, info = tfds.load(
    name='bee_dataset',
    split='train',
    as_supervised=True,
    shuffle_files=False,
    with_info=True,
)

# Keep only the varroa label as the prediction target.
varroa_ds = bee_ds.map(
    lambda image, labels: (image, labels["varroa_output"])
)

Compute `class_weights` to counter the class imbalance of the varroa labels during training.

In [None]:
# Gather the varroa labels in a single pass: every pass over `varroa_ds`
# reads (and decodes) all images, so iterating it once per use is costly.
varroa_labels = np.fromiter(
    varroa_ds.map(lambda image, label: label).as_numpy_iterator(),
    dtype=float,
)

# sklearn's 'balanced' mode: n_samples / (n_classes * count(class)).
varroa_classes = np.unique(varroa_labels)
varroa_weights = compute_class_weight(
    class_weight='balanced',
    classes=varroa_classes,
    y=varroa_labels,
)

# Keras' `class_weight` expects integer class indices as keys. The labels are
# presumably 0.0/1.0 floats for this binary target -- TODO confirm.
class_weights = {
    int(label): float(weight)
    for label, weight in zip(varroa_classes, varroa_weights)
}

class_weights

## Model engineering

We will use the `EfficientNetV2S` architecture (trained from scratch, i.e. without pretrained weights) to build a binary classifier for `varroa_ds`.

In [None]:
# Seed the Python, NumPy and TensorFlow RNGs so that weight initialisation
# and the later dataset shuffle are repeatable. Some ops may still be
# non-deterministic.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Train from scratch (`weights=None`). `classes=1` with a sigmoid head gives a
# single probability output for the binary varroa target.
model = tf.keras.applications.EfficientNetV2S(
    include_top=True,
    weights=None,
    input_shape=info.features['input'].shape,
    pooling=None,
    classes=1,
    classifier_activation='sigmoid',
    include_preprocessing=True,
)

optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

# Compile in graph mode (Keras' default). `run_eagerly=True` is a debugging
# aid that makes training, especially on CPU, dramatically slower.
model.compile(
    optimizer=optimizer,
    loss='binary_crossentropy',
    # Accuracy alone is misleading on imbalanced labels; AUC is threshold-free.
    metrics=['accuracy', tf.keras.metrics.AUC(name='auc')],
)

## Model training on Ætna (but CPU only)

In [None]:
# Training hyper-parameters.
NUM_EPOCHS = 2    # number of passes over the training data
BATCH_SIZE = 256  # examples per gradient step

Next, we build our training data input pipeline.

In [None]:
train_ds = bee_ds.map(lambda x, y: (x, y["varroa_output"])) \
                  .shuffle(bee_ds.cardinality()) \
                  .repeat(NUM_EPOCHS) \
                  .batch(BATCH_SIZE) \
                  .prefetch(tf.data.AUTOTUNE)

Let's fit our model and do not forget `class_weights`!

In [None]:
model.fit(
    train_ds,
    verbose=1,
    class_weight=class_weights
)