# Model Pre-Training

### Idea 

Train a classification model on human radiograph images, then fine-tune it on veterinary medical images.

The dataset labels were extracted from the radiologist reports and encoded as follows: blank for unmentioned, 0 for negative, -1 for uncertain, and 1 for positive. See the [CheXpert competition page](https://stanfordmlgroup.github.io/competitions/chexpert/).

## Data Processing

In [1]:
import os
from glob import glob
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.cm as cm
from PIL import Image

from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import layers

from keras_cv_attention_models import coatnet
from keras_cv_attention_models import swin_transformer_v2
from keras_cv_attention_models import nfnets
from keras_cv_attention_models import maxvit

from typing import List

from data_preparation import prepare_data
from model_utils import Lion

### Other models to try:

- CAFormer
- EfficientNetV2M
- VOLO
- **Dino v2 base and large** (base weights are downloaded)
- **EfficientNetV1B7**
- TinyViT (or small convnext model) on 1024 x 1024 (without pretrained weights)

### Try using object detection pretrained model

- YOLO
- **EfficientDetD7**

The code below can be used to verify that the GPU is in use.

In [2]:
!nvidia-smi

Failed to initialize NVML: Unknown Error


In [3]:
TARGET_HEIGHT = 640
TARGET_WIDTH = 640

In [None]:
X_train, y_train, train_labels_df, targets = prepare_data(split='train')
X_val, y_val, valid_labels_df, _ = prepare_data(split='valid')

In [None]:
train_labels_df.shape

In [None]:
valid_labels_df.shape

In [None]:
chexnet_targets = ['No Finding',
       'Enlarged Cardiomediastinum', 'Cardiomegaly', 'Lung Opacity',
       'Lung Lesion', 'Edema', 'Consolidation', 'Pneumonia', 'Atelectasis',
       'Pneumothorax', 'Pleural Effusion', 'Pleural Other', 'Fracture',
       'Support Devices']

chexpert_targets = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Pleural Effusion']

### Uncertainty Approaches
The CheXpert paper outlines several approaches to handling the uncertain (-1) labels in the data:

- Ignoring - essentially removing from the calculation in the loss function
- Binary mapping - sending uncertain values to either 0 or 1
- Prevalence mapping - use the prevalence rate of the feature as its target value
- Self-training - consider the uncertain values as unlabeled
- 3-Class Classification - retain a separate value for uncertain and try to predict it as a class in its own right
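
As an illustration of the "Ignoring" approach, here is a minimal NumPy sketch of a binary cross-entropy that masks out uncertain entries. The function name `masked_binary_crossentropy` and its arguments are hypothetical and not part of this notebook's pipeline, and it assumes blank labels have already been filled in.

```python
import numpy as np

def masked_binary_crossentropy(y_true, y_prob, uncertain_value=-1.0, eps=1e-7):
    """Mean binary cross-entropy over entries whose label is not uncertain."""
    y_true = np.asarray(y_true, dtype=float)
    y_prob = np.clip(np.asarray(y_prob, dtype=float), eps, 1.0 - eps)
    mask = y_true != uncertain_value
    # Replace uncertain targets with a dummy 0 so the log terms stay finite;
    # the mask removes their contribution below.
    safe_true = np.where(mask, y_true, 0.0)
    losses = -(safe_true * np.log(y_prob) + (1.0 - safe_true) * np.log(1.0 - y_prob))
    n_kept = mask.sum()
    return float((losses * mask).sum() / n_kept) if n_kept else 0.0

# The uncertain middle entry does not affect the loss.
loss = masked_binary_crossentropy([1, -1, 0], [0.5, 0.9, 0.5])
all_uncertain = masked_binary_crossentropy([-1, -1], [0.2, 0.8])
```

Averaging over the kept entries only (rather than over all entries) keeps the loss scale independent of how many labels are uncertain in a batch.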

The paper reports the results of experiments with these approaches as AUC scores with 95% confidence intervals, which indicate the most accurate approach for each feature.
    
|Approach / Feature|Atelectasis|Cardiomegaly|Consolidation|Edema|Pleural Effusion|
|-----------|-----------|-----------|-----------|-----------|-----------|
|`U-Ignore`|0.818(0.759,0.877)|0.828(0.769,0.888)|0.938(0.905,0.970)|0.934(0.893,0.975)|0.928(0.894,0.962)|
|`U-Zeros`|0.811(0.751,0.872)|0.840(0.783,0.897)|0.932(0.898,0.966)|0.929(0.888,0.970)|0.931(0.897,0.965)|
|`U-Ones`|**0.858(0.806,0.910)**|0.832(0.773,0.890)|0.899(0.854,0.944)|0.941(0.903,0.980)|0.934(0.901,0.967)|
|`U-Mean`|0.821(0.762,0.879)|0.832(0.771,0.892)|0.937(0.905,0.969)|0.939(0.902,0.975)|0.930(0.896,0.965)|
|`U-SelfTrained`|0.833(0.776,0.890)|0.831(0.770,0.891)|0.939(0.908,0.971)|0.935(0.896,0.974)|0.932(0.899,0.966)|
|`U-MultiClass`|0.821(0.763,0.879)|**0.854(0.800,0.909)**|0.937(0.905,0.969)|0.928(0.887,0.968)|0.936(0.904,0.967)|

The binary mapping approaches (`U-Ones` and `U-Zeros`) are the easiest to implement, so to begin with we take the better of the two for each feature:

- Atelectasis `U-Ones`
- Cardiomegaly `U-Zeros`
- Consolidation `U-Zeros`
- Edema `U-Ones`
- Pleural Effusion `U-Ones`
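
A minimal sketch of binary mapping for a single label column follows. The helper `map_uncertain` is hypothetical (the actual mapping lives in `prepare_data`, which is not shown here), and treating blank entries as negative is an assumption made for this example.

```python
import numpy as np

def map_uncertain(labels, policy):
    """Map uncertain (-1) labels to 1 ("U-Ones") or 0 ("U-Zeros")."""
    if policy not in ("U-Ones", "U-Zeros"):
        raise ValueError("policy must be 'U-Ones' or 'U-Zeros'")
    labels = np.asarray(labels, dtype=float)
    fill = 1.0 if policy == "U-Ones" else 0.0
    mapped = np.where(labels == -1.0, fill, labels)
    # Assumption: a blank (NaN) label means "not mentioned" and is treated as negative.
    return np.nan_to_num(mapped, nan=0.0)

raw = [1.0, 0.0, -1.0, np.nan]                # positive, negative, uncertain, blank
atelectasis = map_uncertain(raw, "U-Ones")    # uncertain becomes 1
cardiomegaly = map_uncertain(raw, "U-Zeros")  # uncertain becomes 0
```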

In [None]:
train_labels_df['valid'] = False
valid_labels_df['valid'] = True

In [None]:
full_df = pd.concat([train_labels_df, valid_labels_df])
full_df.head()

### View a sample of images and labels

In [None]:
# Get the first 5 image paths and their labels (positional, independent of the index)
paths = full_df.path.iloc[:5].tolist()
labels = full_df.feature_string.iloc[:5].tolist()

fig, m_axs = plt.subplots(1, len(labels), figsize=(20, 10))
# Show the images and label them
for ii, c_ax in enumerate(m_axs):
    c_ax.imshow(np.asarray(Image.open(paths[ii])), cmap='gray')
    c_ax.set_title(labels[ii])
plt.show()

## Original Image

In [None]:
img = np.asarray(Image.open(X_val[0]))

In [None]:
img = np.stack((img,)*3, axis=-1)

In [None]:
img.shape

In [None]:
f, ax = plt.subplots(1, 1, figsize=(5, 5))

ax.imshow(img)
ax.axis('off')
ax.set_aspect('auto')

plt.show() 

## Augmented Image

In [None]:
def apply_mask(image, size=12, n_squares=1):
    """Cutout-style augmentation: zero out `n_squares` random square patches of side `size`."""
    h, w, channels = image.shape
    new_image = np.array(image)  # writable copy of the input
    for _ in range(n_squares):
        # Pick a random patch centre; the patch is clipped at the image borders
        y = np.random.randint(h)
        x = np.random.randint(w)
        y1 = np.clip(y - size // 2, 0, h)
        y2 = np.clip(y + size // 2, 0, h)
        x1 = np.clip(x - size // 2, 0, w)
        x2 = np.clip(x + size // 2, 0, w)
        new_image[y1:y2, x1:x2, :] = 0
    return new_image

In [None]:
if np.random.uniform() < 0.5:
    augmented = apply_mask(img, size=np.random.randint(low=70, high=240), n_squares=np.random.randint(low=2, high=12))
else:
    augmented = tf.image.random_brightness(img, max_delta=0.2)
    augmented = tf.image.random_saturation(image=augmented, lower=0.8, upper=1.2)
    augmented = tf.image.random_hue(image=augmented, max_delta=0.03)
    augmented = tf.image.random_contrast(image=augmented, lower=0.8, upper=1.2)
# augmented = tf.image.random_flip_up_down(img)
# augmented = tf.image.random_flip_left_right(img)
# augmented = tf.image.random_saturation(image=img, lower=0.7, upper=1.3)
# augmented = tf.image.random_hue(image=img, max_delta=0.03)
# augmented = tf.image.random_contrast(image=img, lower=0.7, upper=1.3)

In [None]:
augmented.shape

In [None]:
f, ax = plt.subplots(1, 1, figsize=(5, 5))

ax.imshow(augmented, cmap='gray')
ax.axis('off')
ax.set_aspect('auto')

plt.show() 

In [None]:
NUM_TRAIN = len(X_train)

## Determine class weights 

In [None]:
target_columns = [col_name + '_label' for col_name in targets]

In [None]:
plt.xticks(rotation=90)
plt.bar(x=target_columns, height=y_train.sum(axis=0))
plt.show()

In [None]:
class_counts = y_train.sum(axis=0)
total_count = y_train.sum()

In [None]:
cls_weights = {i: total_count/class_i_count for i, class_i_count in enumerate(class_counts)}

In [None]:
cls_weights

In [None]:
cls_weights_sqrt = {i: np.sqrt(weight) for i, weight in enumerate(list(cls_weights.values()))}

In [None]:
cls_weights_sqrt

In [None]:
cls_weights_log = {i: np.log(weight) for i, weight in enumerate(list(cls_weights.values()))}

In [None]:
cls_weights_log

In [None]:
y_train_weighted = y_train * np.array(list(cls_weights.values()))

In [None]:
plt.xticks(rotation=90)
plt.bar(x=target_columns, height=y_train_weighted.sum(axis=0))
plt.show()

In [None]:
y_train_weighted_sqrt = y_train * np.array(list(cls_weights_sqrt.values()))

In [None]:
plt.xticks(rotation=90)
plt.bar(x=target_columns, height=y_train_weighted_sqrt.sum(axis=0))
plt.show()

In [None]:
y_train_weighted_log = y_train * np.array(list(cls_weights_log.values()))

In [None]:
plt.xticks(rotation=90)
plt.bar(x=target_columns, height=y_train_weighted_log.sum(axis=0))
plt.show()

In [None]:
USE_CLASS_WEIGHTS = True

In [None]:
if USE_CLASS_WEIGHTS:
    CLASS_WEIGHTS = cls_weights_log
else:
    CLASS_WEIGHTS = None
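
To make the difference between the three weighting schemes concrete, here is a small worked example on made-up class counts (the counts are hypothetical, not taken from the dataset). It follows the same formulas as the cells above: inverse frequency, its square root, and its natural log.

```python
import numpy as np

# Hypothetical positive counts for three classes (common, medium, rare)
class_counts = np.array([1000.0, 100.0, 10.0])
total_count = class_counts.sum()

inverse_weights = total_count / class_counts  # same formula as cls_weights
sqrt_weights = np.sqrt(inverse_weights)       # same formula as cls_weights_sqrt
log_weights = np.log(inverse_weights)         # same formula as cls_weights_log

def spread(weights):
    """Ratio between the largest and smallest weight."""
    return float(weights.max() / weights.min())

print(inverse_weights, spread(inverse_weights))
print(sqrt_weights, spread(sqrt_weights))
print(log_weights, spread(log_weights))
```

The square root and log both shrink the largest weight considerably. Note that the log weight of a class approaches 0 as that class approaches all of the positive labels, so a dominant class can end up contributing almost nothing to the loss.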

## Define dataset generator

In [None]:
def convert_image_to_array(path):
    img = np.asarray(Image.open(path), dtype=np.float32)
    img = np.stack((img,)*3, axis=-1)
    img /= 255.
    img = tf.image.resize_with_pad(img, target_height=TARGET_HEIGHT, target_width=TARGET_WIDTH)
    return img

In [None]:
def create_model_file(X_path, y):
    """
    X_path: (pandas series) contains the file paths to the images
    y: (pandas series of type int) the target label
    
    return a pair of numpy arrays representing (features, target)
    """
    
    X = pd.Series(X_path).apply(convert_image_to_array)
    X = X.values
    X = list(X)
    X = np.array(X, dtype='float32')
    
    return (X, y)

In [None]:
def model_predict(path, model):
    x = convert_image_to_array(path=path)
    x = np.expand_dims(x, axis=0)
    return model.predict(x)

In [None]:
val_data = create_model_file(X_path=X_val, y=y_val)

In [None]:
val_data[0].shape

In [None]:
Image.fromarray(np.uint8(255 * val_data[0][1]))

In [None]:
num_classes = y_train.shape[1]

In [None]:
def transform_image(img_path, target_image_size, dtype, scale_image):
    # Read the grayscale image and replicate it across 3 channels
    img = np.asarray(Image.open(img_path), dtype=dtype)
    img = np.stack((img,)*3, axis=-1)

    # Scale to [0, 1] before augmenting: for float images the tf.image color ops
    # below apply their deltas directly, so they assume pixel values in that range
    if scale_image:
        img = img/255.

    # Add image augmentation: random square masks or mild color jitter
    if np.random.uniform() < 0.5:
        img = apply_mask(img, size=np.random.randint(low=70, high=240), n_squares=np.random.randint(low=2, high=12))
    else:
        if np.random.uniform() < 0.15:
            img = tf.image.random_brightness(img, max_delta=0.2)
        if np.random.uniform() < 0.15:
            img = tf.image.random_saturation(image=img, lower=0.8, upper=1.2)
        if np.random.uniform() < 0.15:
            img = tf.image.random_hue(image=img, max_delta=0.03)
        if np.random.uniform() < 0.15:
            img = tf.image.random_contrast(image=img, lower=0.8, upper=1.2)

    # Resize the image, padding to preserve the aspect ratio
    img = tf.image.resize_with_pad(img, target_height=target_image_size[0], target_width=target_image_size[1])

    return img

#### Visualize some transformed images 

In [None]:
img_path = X_val[1]

Original

In [None]:
Image.open(img_path)

In [None]:
img = transform_image(img_path=img_path, target_image_size=(TARGET_HEIGHT, TARGET_WIDTH), dtype=np.float32, scale_image=True)

In [None]:
Image.fromarray(np.uint8(255 * img.numpy()))

In [None]:
def data_gen(X, y, batch_size, image_size=(TARGET_HEIGHT, TARGET_WIDTH), dtype=np.float32, scale_image=True):
    """Yield (images, labels) batches forever, reshuffling once per pass over the data."""
    n = len(X)
    # Full batches per pass (any remainder is dropped); at least one batch
    steps = max(1, n // batch_size)

    # Get a numpy array of all the indices of the input data
    indices = np.arange(n)

    while True:
        # Shuffle once per pass so a sample is not repeated within the same pass
        np.random.shuffle(indices)
        for i in range(steps):
            # Allocate fresh arrays so a batch that was already yielded is never overwritten
            batch_data = np.zeros((batch_size, image_size[0], image_size[1], 3), dtype=dtype)
            batch_labels = np.zeros((batch_size, num_classes), dtype=dtype)

            next_batch = indices[i*batch_size:(i+1)*batch_size]
            for j, idx in enumerate(next_batch):
                # Transform/augment the image
                batch_data[j] = transform_image(img_path=X[idx], target_image_size=image_size, dtype=dtype, scale_image=scale_image)
                # Labels are already per-class 0/1 vectors, so no extra encoding is needed
                batch_labels[j] = y[idx]

            yield batch_data, batch_labels

## Keras Utility Functions

Define some helper functions that simplify fine-tuning pre-trained models.

In [None]:
def freeze_layers(model, freeze_layer_name):
    # Freeze every layer up to and including `freeze_layer_name`
    for layer in model.layers:
        layer.trainable = False
        if layer.name == freeze_layer_name:
            break
            
def unfreeze_batch_norm(model):
    for layer in model.layers:
        if layer.__class__.__name__ == 'BatchNormalization':
            layer.trainable = True
            
def unfreeze_layer_norm(model):
    for layer in model.layers:
        if layer.__class__.__name__ == 'LayerNormalization':
            layer.trainable = True

def print_layer_trainable(model):
    for layer in model.layers:
        print('{0}:\t{1}'.format(layer.trainable, layer.name))

## (Optional) Load pretrained model

In [None]:
# List all models
!ls ../models

In [None]:
MODEL_BASE_NAME = 'tiny_vit_21m_512_imagenet21k-ft1k'
MODEL_NAME = f'{MODEL_BASE_NAME}.h5'
model_path = f'../models/{MODEL_NAME}'
model = tf.keras.models.load_model(model_path)

In [None]:
model.summary()

## Build the model

In [None]:
# model = swin_transformer_v2.SwinTransformerV2Base_window24(input_shape=(TARGET_HEIGHT, TARGET_WIDTH, 3))

In [None]:
# model = coatnet.CoAtNet2(input_shape=(TARGET_HEIGHT, TARGET_WIDTH, 3))

In [None]:
# model = nfnets.ECA_NFNetL3(input_shape=(TARGET_HEIGHT, TARGET_WIDTH, 3))

In [None]:
# model = maxvit.MaxViT_Small(input_shape=(TARGET_HEIGHT, TARGET_WIDTH, 3))

In [None]:
# model.layers[-6].name

In [None]:
# model.count_params()

## Load Pretrained Model (alternative)

In [None]:
# # model_path='./serialized_models/pretrain_model_ConvNeXtBase_w_ClssWgt_01-0.3887.h5'
# model_path='./serialized_models/pretrain_model_ConvNeXtBase_w_ClssWgt_01-0.4021.h5'
# # model_path='./serialized_models/pretrain_model_ConvNeXtSmall_w_ClssWgt_03-0.5216.h5'

In [None]:
# from keras.applications.convnext import LayerScale
# model = tf.keras.models.load_model(model_path, custom_objects={'LayerScale': LayerScale})

Determine where to freeze and cut off base model

In [None]:
# # Note: these were used to train SwinTransformerV2Base_window24
# transfer_layer_name = 'pre_output_ln'
# transfer_layer = model.get_layer(transfer_layer_name)

In [None]:
# # Note: these were used to train CoAtNet2
# transfer_layer_name = 'stack_4_block_2_ffn_output'
# transfer_layer = model.get_layer(transfer_layer_name)

In [None]:
# # Note: these were used to train ECA_NFNetL3
# transfer_layer_name = 'post_swish'
# transfer_layer = model.get_layer(transfer_layer_name)

In [None]:
# # Note: these were used to train MaxViT_Small
# transfer_layer_name = 'stack_4_block_2/grid_ffn_output'
# transfer_layer = model.get_layer(transfer_layer_name)
# freeze_layer_name = 'stack_4_block_2/grid_ffn_output'

In [None]:
# # Note: these were used to train YOLOV8 X
# transfer_layer_name = 'tf.concat_15'
# transfer_layer = model.get_layer(transfer_layer_name)
# freeze_layer_name = 'tf.concat_15'

In [None]:
# # Note: these were used to train YOLOV8 X6
# transfer_layer_name = 'tf.concat_45'
# transfer_layer = model.get_layer(transfer_layer_name)
# freeze_layer_name = 'head_4_cls_3_conv' # Also try going back to this layer: head_4_cls_2_swish

In [None]:
# # Note: these were used to train efficientnetv1-b7-noisy_student
# transfer_layer_name = 'post_swish'
# transfer_layer = model.get_layer(transfer_layer_name)
# freeze_layer_name = 'post_swish' 

In [None]:
# # Note: these were used to train tiny vit
transfer_layer_name = 'stack4_block2_mlp_output'
transfer_layer = model.get_layer(transfer_layer_name)
freeze_layer_name = 'stack4_block2_mlp_output' 

In [None]:
# model.summary()

In [None]:
conv_model = tf.keras.Model(inputs=model.input, outputs=transfer_layer.output)

In [None]:
conv_model.summary()

2 ideas for adapting YOLO:
1. Just do pooling/flatten and then concatenate the various levels of granularity
2. Add convolution/pooling layers to get everything to match the same dimension then do pooling/flatten
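
To get a feel for the feature sizes involved in idea 1, here is a small NumPy sketch that pools and concatenates four feature maps. The spatial sizes (80, 40, 20, 10) are assumed from the variable names used for the YOLO outputs in this notebook, and the channel count is a made-up placeholder. It compares local 2x2 average pooling followed by flattening with global average pooling.

```python
import numpy as np

def avg_pool_2x2(feature_map):
    """Non-overlapping 2x2 average pooling on an (H, W, C) array with even H and W."""
    h, w, c = feature_map.shape
    return feature_map.reshape(h // 2, 2, w // 2, 2, c).mean(axis=(1, 3))

channels = 8  # hypothetical channel count, identical at every scale for simplicity
feature_maps = [np.ones((size, size, channels), dtype=np.float32) for size in (80, 40, 20, 10)]

# Local pooling + flatten keeps spatial layout, so the vector grows with H * W
local_vector = np.concatenate([avg_pool_2x2(fm).ravel() for fm in feature_maps])

# Global pooling keeps one value per channel per scale
global_vector = np.concatenate([fm.mean(axis=(0, 1)) for fm in feature_maps])

print(local_vector.shape, global_vector.shape)
```

With local pooling the finest scale dominates the concatenated vector, which also makes the final dense layer much larger than it would be with global pooling.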

In [None]:
def build_yolo_model(base_model, num_classes, dropout_rate=0):
    # Get the output of the base model on which we will build
    x_80 = base_model.get_layer('tf.concat_41')
    x_40 = base_model.get_layer('tf.concat_42')
    x_20 = base_model.get_layer('tf.concat_43')
    x_10 = base_model.get_layer('tf.concat_44')
    
    x_80 = layers.AveragePooling2D((2,2), name='local_avg_pool_80')(x_80.output)
    x_40 = layers.AveragePooling2D((2,2), name='local_avg_pool_40')(x_40.output)
    x_20 = layers.AveragePooling2D((2,2), name='local_avg_pool_20')(x_20.output)
    x_10 = layers.AveragePooling2D((2,2), name='local_avg_pool_10')(x_10.output)
    
    x_80 = layers.Flatten(name='flatten_80')(x_80)
    x_40 = layers.Flatten(name='flatten_40')(x_40)
    x_20 = layers.Flatten(name='flatten_20')(x_20)
    x_10 = layers.Flatten(name='flatten_10')(x_10)

    x_80 = keras.layers.Dropout(dropout_rate)(x_80)
    x_40 = keras.layers.Dropout(dropout_rate)(x_40)
    x_20 = keras.layers.Dropout(dropout_rate)(x_20)
    x_10 = keras.layers.Dropout(dropout_rate)(x_10)
    
    x = layers.Concatenate(axis=-1, name='multi_scale_concat')([x_80, x_40, x_20, x_10])
        
    x = layers.Dense(num_classes, activation='sigmoid', name='prediction')(x)

    # Create model.
    model = tf.keras.Model(base_model.input, x, name='Xception')
    return model

In [None]:
def build_model(base_model, num_classes, pooling='avg', final_conv_layer='vgg_separable', expand_model=True, dropout_rate=0):
    # Get the output of the base model on which we will build
    x = base_model.layers[-1].output
    
    if expand_model:
        if final_conv_layer == 'xception':
            x = layers.SeparableConv2D(2048, (3, 3), padding='same', use_bias=False, name='block14_sepconv2')(x)
            x = layers.BatchNormalization(name='block14_sepconv2_bn')(x)
            x = layers.Activation('relu', name='block14_sepconv2_act')(x)
            x = keras.layers.Dropout(dropout_rate / 2)(x)
        elif final_conv_layer == 'non_separable':
            x = layers.Conv2D(2048, (3, 3), padding='same', use_bias=False, name='block14_conv2')(x)
            x = layers.BatchNormalization(name='block14_conv2_bn')(x)
            x = layers.Activation('relu', name='block14_conv2_act')(x)
            x = keras.layers.Dropout(dropout_rate / 2)(x)
        elif final_conv_layer == 'vgg_separable':
            x = layers.SeparableConv2D(2048, (3,3), activation='relu', padding='same', kernel_regularizer='l1_l2', bias_regularizer='l1_l2', name='block14_sepconv2')(x)
            x = keras.layers.Dropout(dropout_rate / 2)(x)
        elif final_conv_layer == 'vgg':
            x = layers.Conv2D(2048, (3,3), activation='relu', padding='same', name='block14_sepconv2')(x)
            x = keras.layers.Dropout(dropout_rate / 2)(x)
        else:
            raise ValueError('`final_conv_layer` should be one of the following: xception, non_separable, vgg_separable, or vgg')

    if pooling == 'global_avg':
        x = layers.GlobalAveragePooling2D(name='global_avg_pool')(x)
    elif pooling == 'global_max':
        x = layers.GlobalMaxPooling2D(name='global_max_pool')(x)
    elif pooling == 'max':
        x = layers.MaxPooling2D((2,2), name='local_max_pool')(x)
        x = layers.Flatten(name='flatten')(x)
    elif pooling == 'avg':
        x = layers.AveragePooling2D((2,2), name='local_avg_pool')(x)
        x = layers.Flatten(name='flatten')(x)
    else:
        pass
    x = keras.layers.Dropout(dropout_rate)(x)
        
    x = layers.Dense(num_classes, activation='sigmoid', name='prediction')(x)

    # Create model.
    model = tf.keras.Model(base_model.input, x, name='Xception')
    return model

## Determine good starting learning rate

Find a workable learning-rate range empirically: start with a very low learning rate, record the training loss reached after a fixed number of steps, and increase the learning rate incrementally until it is too high for training to remain stable. Use this range to determine the initial learning rate.

Create a function to do this analysis

In [None]:
def determine_learning_rate(X, y, batch_size: int, lr_list: List[float], steps: int):
    train_loss_by_lr = []
    
    for i, lr in enumerate(lr_list):
        print(f'Learning rate {i + 1} of {len(lr_list)}. LR value: {lr}')
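        # Fresh ImageNet-pretrained backbone for every learning rate so the runs are comparable
        local_model = tf.keras.applications.ConvNeXtBase(include_top=False, weights='imagenet', input_shape=(TARGET_HEIGHT, TARGET_WIDTH, 3), include_preprocessing=False)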
        local_conv_model = tf.keras.Model(inputs=local_model.input, outputs=local_model.output)
        
        model_lr = build_model(base_model=local_conv_model, num_classes=num_classes, dropout_rate=0.3)
        
        model_lr.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
            loss="binary_crossentropy",
            metrics=["accuracy"],
        )
        
        hist = model_lr.fit(
            x=data_gen(X=X, y=y, batch_size=batch_size), 
            epochs=1, 
            steps_per_epoch=steps,
            class_weight=CLASS_WEIGHTS,
        )
        
        final_loss = hist.history['loss'][-1]
        final_acc = hist.history['accuracy'][-1]
        
        train_loss_by_lr.append((lr, final_loss, final_acc))
      
    losses_df = pd.DataFrame(train_loss_by_lr, columns=['learning_rate', 'training_loss', 'training_accuracy'])
    
    return losses_df

In [None]:
# lr_loss_df = determine_learning_rate(X=X_train, 
#                                      y=y_train, 
#                                      batch_size=2, 
#                                      lr_list=[1e-2, 5e-3, 1e-3, 5e-4, 1e-4, 5e-5, 1e-5, 5e-6, 1e-6, 1e-7, 1e-8,],
#                                      steps=200)

In [None]:
# plt.scatter(x=lr_loss_df['learning_rate'].values, y=lr_loss_df['training_loss'].values)
# plt.xscale('log')
# plt.show()
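
The sweep procedure can be illustrated on a toy problem without any model training. The sketch below runs plain gradient descent on a one-dimensional quadratic loss for each learning rate, always restarting from the same initial weight, and records the final loss. The function `sweep_learning_rates` and the quadratic objective are stand-ins chosen for this example only.

```python
def sweep_learning_rates(lr_list, steps=20, curvature=4.0, w0=1.0):
    """Run `steps` of gradient descent on 0.5 * curvature * w**2 for each learning rate."""
    results = []
    for lr in lr_list:
        w = w0  # fresh "model" for every learning rate
        for _ in range(steps):
            grad = curvature * w
            w -= lr * grad
        results.append((lr, 0.5 * curvature * w ** 2))
    return results

initial_loss = 0.5 * 4.0 * 1.0 ** 2
results = sweep_learning_rates([1e-4, 1e-3, 1e-2, 1e-1, 1.0])
best_lr, best_loss = min(results, key=lambda r: r[1])

for lr, loss in results:
    print(f"lr={lr:g}\tfinal loss={loss:.3e}")
```

Very small learning rates barely reduce the loss in the step budget, while the largest one diverges. The useful range lies in between, which is the same pattern to look for in the scatter plot of `training_loss` against `learning_rate`.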

## Train model

In [None]:
# Ideally retrain the entire model, but memory is constrained 
freeze_layers(conv_model, freeze_layer_name)

In [None]:
model = build_model(base_model=conv_model, num_classes=num_classes, dropout_rate=0.3)
# model = build_yolo_model(base_model=conv_model, num_classes=num_classes, dropout_rate=0.3)

In [None]:
model.summary()

In [None]:
# unfreeze_batch_norm(model)

In [None]:
# unfreeze_layer_norm(model)

In [None]:
print_layer_trainable(model)

## Lion Optimizer

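[Lion](https://arxiv.org/pdf/2302.06675.pdf) is an optimizer that applies sign-based updates and tracks only a momentum term, so it needs less optimizer memory than Adam. Its authors report faster convergence and better final models on a range of tasks, and recommend a smaller learning rate than one would use with Adam. The official implementation is [here](https://github.com/google/automl/blob/master/lion/lion_tf2.py).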
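
For intuition, a single Lion update step can be sketched in NumPy as below. This is an illustrative re-statement of the update rule from the paper, not the `Lion` class imported from `model_utils`; the function name `lion_step` and the example values are hypothetical.

```python
import numpy as np

def lion_step(param, grad, momentum, lr=1e-4, beta1=0.9, beta2=0.99, weight_decay=0.0):
    """One Lion update: a sign-based step plus decoupled weight decay."""
    # Interpolate between momentum and gradient, then keep only the sign
    update = np.sign(beta1 * momentum + (1.0 - beta1) * grad)
    new_param = param - lr * (update + weight_decay * param)
    # The momentum is updated with a second, slower interpolation factor
    new_momentum = beta2 * momentum + (1.0 - beta2) * grad
    return new_param, new_momentum

param = np.array([1.0, -2.0])
grad = np.array([0.3, -5.0])
momentum = np.zeros_like(param)

new_param, new_momentum = lion_step(param, grad, momentum, lr=0.1)
```

Because of the sign, every coordinate moves by exactly `lr` regardless of the gradient magnitude, which is one reason a lower learning rate is typically used with this optimizer.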

## Train the model

In [None]:
batch_size = 16

In [None]:
lr_schedule = keras.optimizers.schedules.CosineDecayRestarts(
      # initial_learning_rate=3e-3, # for uniform weighting
      # initial_learning_rate=3e-4, # for sqrt weighting 
      initial_learning_rate=1e-5, # lower learning rate for Lion optimizer
      first_decay_steps=int(NUM_TRAIN/ (7 * batch_size)))

In [None]:
# The first cosine cycle lasts 1/7 of an epoch; with the default t_mul=2.0 each later cycle is twice as long
print(NUM_TRAIN/ (7 * batch_size))
print(int(NUM_TRAIN/ (7 * batch_size)))
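
To see where the restarts of this schedule fall, the sketch below re-implements cosine decay with warm restarts in plain Python. It is intended to mirror the behaviour of `CosineDecayRestarts` with its default arguments (`t_mul=2.0`, `m_mul=1.0`, `alpha=0.0`), but it is a stand-alone approximation rather than the Keras code, and `first_decay_steps=100` is an arbitrary example value.

```python
import math

def cosine_decay_restarts(step, initial_lr, first_decay_steps, t_mul=2.0, m_mul=1.0, alpha=0.0):
    """Learning rate at `step` for cosine decay with warm restarts."""
    period = float(first_decay_steps)
    start = 0.0
    scale = 1.0
    # Walk forward through the cycles until we find the one containing `step`
    while step >= start + period:
        start += period
        period *= t_mul  # each cycle is t_mul times longer than the previous one
        scale *= m_mul   # each restart begins at m_mul times the previous peak
    fraction = (step - start) / period
    cosine = 0.5 * scale * (1.0 + math.cos(math.pi * fraction))
    return initial_lr * ((1.0 - alpha) * cosine + alpha)

def lr_at(step):
    return cosine_decay_restarts(step, initial_lr=1e-5, first_decay_steps=100)

# Steps at which the learning rate jumps back up, i.e. the restarts
restarts = [s for s in range(1, 701) if lr_at(s) > lr_at(s - 1)]
print(restarts)
```

With a doubling period, the restarts land after 1, 3, and 7 multiples of `first_decay_steps`. When `first_decay_steps` is one seventh of an epoch, that corresponds to three cycles in the first epoch, with later cycles spanning more than one epoch.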

In [None]:
epochs = 10
model_path='../models/pretrain_model_tiny_vit_21m_512_imagenet21k-ft1k_w_ClssWgt_{epoch:02d}-{val_loss:.4f}.h5'

callbacks = [
    tf.keras.callbacks.ModelCheckpoint(filepath=model_path, save_best_only=False),
#     tf.keras.callbacks.EarlyStopping(patience=10)
]
model.compile(
#     optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule),
    optimizer=Lion(learning_rate=lr_schedule),
    loss="binary_crossentropy",
    metrics=["accuracy"],
)
model.fit(
    x=data_gen(X=X_train, y=y_train, batch_size=batch_size), 
    epochs=epochs, 
    callbacks=callbacks, 
    validation_data=val_data, 
    validation_batch_size=4,
    steps_per_epoch=int(NUM_TRAIN/batch_size),
    class_weight=CLASS_WEIGHTS,
)

- SwinV2: 1000/111707: loss: 1.446, accuracy: 0.0624
- SwinV2: 8000/111707: loss: 1.376, accuracy: 0.0392
- CoAtNet2: 1000/111707: loss: 1.468, accuracy: 0.0750
- CoAtNet2: 4200/111707: loss: 1.371, accuracy: 0.0840
- NFNet (w/ Lion): 500/55853: loss: 1.465, accuracy: 0.0319
- NFNet (w/ Lion): 2100/55853: loss: 1.37, accuracy: 0.0433
- NFNet (w/ Lion): 4673/55853: loss: 1.33, accuracy: 0.0596
- MaxViT Small (w/ Lion) 640, norm layers unfrozen 1e-5 lr: 500/111707: loss: 1.465, accuracy: 0.0319
- MaxViT Small (w/ Lion) 640, norm layers unfrozen 1e-5 lr: 5500/111707: loss: 1.247, accuracy: 0.1242
- MaxViT Small (w/ Lion) 640, norm layers unfrozen 1e-5 lr: 13500/111707: loss: 1.196, accuracy: 0.15

- MaxViT Base (w/ Lion) 512 imagenet21k weights, log class weights, norm layers frozen 3e-5 lr: 3000/55853: loss: 1.0627, accuracy: 0.0976
- MaxViT Base (w/ Lion) 512 imagenet21k weights, log class weights, norm layers unfrozen 1e-5 lr: 4000/74471: loss: 1.4759, accuracy: 0.1263

- YOLOV8 X6 (w/ Lion) 640 coco weights, log class weights, norm layers frozen 2e-5 lr, batch size 8: 500/27926: loss: 1.54, accuracy: 0.13
- YOLOV8 X6 (w/ Lion) 640 coco weights, log class weights, norm layers frozen 2e-5 lr, batch size 8: 2000/27926: loss: 1.476, accuracy: 0.1345
- YOLOV8 X6 (w/ Lion) 640 coco weights, log class weights, norm layers frozen 2e-5 lr, batch size 8: 5500/27926: loss: 1.335, accuracy: 0.141 

- EfficientNet B7 (w/ Lion) 600 imagenet weights, log class weights, norm layers frozen 2e-5 lr, batch size 8: 500/27926: loss: 0.953, accuracy: 0.028
- EfficientNet B7 (w/ Lion) 600 imagenet weights, log class weights, norm layers frozen 2e-5 lr, batch size 8: 2000/27926: loss: 0.904, accuracy: 0.031

- TinyViT (w/ Lion) 512 imagenet21k weights, log class weights, norm layers frozen 1e-5 lr, batch size 16: 340/13963: loss: 0.941, accuracy: 0.0339
- TinyViT (w/ Lion) 512 imagenet21k weights, log class weights, norm layers frozen 1e-5 lr, batch size 16: 800/13963: loss: 0.88, accuracy: 0.056
- TinyViT (w/ Lion) 512 imagenet21k weights, log class weights, norm layers frozen 1e-5 lr, batch size 16: 2700/13963: loss: 0.828, accuracy: 0.108
- TinyViT (w/ Lion) 512 imagenet21k weights, log class weights, norm layers frozen 1e-5 lr, batch size 16: 4000/13963: loss: 0.8113, accuracy: 0.1215

- Old convnet: 13300/55853, loss: 0.3006, accuracy: 0.2011

Save final model - Make sure name is correct!

In [None]:
model.save('./serialized_models/pretrain_model_MaxViT_w_ClssWgt_01-0.unk.h5')