In [10]:
import keras
import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from lib.vit_keras.vit_keras import vit, visualize
import os
import matplotlib.pyplot as plt

In [11]:
DATASET_DIR = "_data"
TRAIN_DIR = "train_b"
VAL_DIR = "val_b"
IMAGES_DIR = "images"
MASKS_DIR = "leaf_instances"
IMAGE_SIZE = (384, 384)
INPUT_SHAPE = IMAGE_SIZE + (3, )
CLASSES = 2
NAME = "VisionTransformer"
EPOCHS = 200
CHECKPOINT_DIR = f"checkpoints/{NAME}"
LOAD = "best"
BATCH_SIZE = 4

In [12]:
def gen_dataset(path, batch_size, lab, input_shape, aug=True):
    data_augmentation = tf.keras.Sequential([
        keras.layers.RandomFlip("horizontal_and_vertical"),
        keras.layers.RandomRotation(.8),
        keras.layers.RandomBrightness(.4),
        keras.layers.RandomContrast(.4),
        keras.layers.RandomZoom((-.2, .2), (-.2, .2)),
        keras.layers.Resizing(INPUT_SHAPE[0], INPUT_SHAPE[1])
    ])
    @tf.function()
    def augment(image, label):
        if (augment):
            return data_augmentation(image), to_categorical(label, num_classes=CLASSES)
        else:
            return x, to_categorical(label, num_classes=CLASSES)
    datagen = keras.utils.image_dataset_from_directory(path, batch_size=batch_size, image_size=input_shape[:2], crop_to_aspect_ratio=True, labels="inferred", label_mode="binary")
    if lab:
        datagen = datagen.map(
            lambda x, y: (transform_wrapper(x, target_size=INPUT_SHAPE[:2], rescale=True, smart_resize=True, lab=True), y)
        )
    datagen = datagen\
        .map(augment, num_parallel_calls=tf.data.AUTOTUNE, deterministic=False)\
        .prefetch(tf.data.AUTOTUNE)
    return datagen

In [13]:
if LOAD == 'latest':
    model_file = tf.train.latest_checkpoint(CHECKPOINT_DIR)
    print("Loading latest model:", model_file)
    model = keras.models.load_model(model_file, safe_mode=False)
if LOAD == 'best':
    model_file = f'out/best_{NAME}_{os.path.basename(DATASET_DIR)}.keras'
    print("Loading best model:", model_file)
    model = keras.models.load_model(model_file, safe_mode=False)
else:
    print("Loading initial model")
    model = vit.vit_l32(
        image_size=IMAGE_SIZE[0],
        activation='sigmoid',
        pretrained=False,
        include_top=True,
        pretrained_top=False,
        classes=CLASSES
    )

Loading best model: out/best_VisionTransformer__data.keras


In [14]:
train_dir = os.path.join(DATASET_DIR, TRAIN_DIR)
val_dir = os.path.join(DATASET_DIR, VAL_DIR)

In [15]:
train_dataset = gen_dataset(train_dir, BATCH_SIZE, False, IMAGE_SIZE)
val_dataset = gen_dataset(val_dir, BATCH_SIZE, False, IMAGE_SIZE, False)

Found 17387 files belonging to 2 classes.
Found 1332 files belonging to 2 classes.


In [16]:
def multiclass_iou_loss(y_true, y_pred, smooth=1e-6):
    """
    Compute the IoU loss for multiclass segmentation.

    :param y_true: True labels, one-hot encoded, shape (batch_size, height, width, num_classes)
    :param y_pred: Predictions, shape (batch_size, height, width, num_classes)
    :param smooth: Smoothing factor to avoid division by zero
    :return: Average IoU loss across all classes
    """
    num_classes = y_pred.shape[-1]
    iou_loss_per_class = []

    for c in range(num_classes):
        y_true_c = y_true[..., c]
        y_pred_c = y_pred[..., c]

        intersection = K.sum(y_true_c * y_pred_c)
        total = K.sum(y_true_c) + K.sum(y_pred_c)
        union = total - intersection

        iou = (intersection + smooth) / (union + smooth)
        iou_loss_per_class.append(1 - iou)

    return K.mean(tf.stack(iou_loss_per_class))

def combined_bce_iou_loss(y_true, y_pred):
    bce_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    iou = multiclass_iou_loss(y_true, y_pred)
    return bce_loss + iou

In [17]:
opt = keras.optimizers.SGD()
model.compile(
    loss='binary_crossentropy',
    optimizer=opt,
    metrics=[
        keras.metrics.BinaryAccuracy(),
        tf.keras.metrics.Recall(),
        tf.keras.metrics.AUC()
    ],
)

callbacks = [
    #keras.callbacks.EarlyStopping(patience=5),
    keras.callbacks.ModelCheckpoint(filepath=CHECKPOINT_DIR + '/model_##name##.{epoch:02d}_##data##.keras'.replace("##name##", NAME).replace('##data##', os.path.basename(DATASET_DIR))),
    keras.callbacks.TensorBoard(log_dir=f'./logs/{NAME}'),
    keras.callbacks.ModelCheckpoint(filepath='out/best_##name##_##data##.keras'.replace('##name##', NAME).replace('##data##', os.path.basename(DATASET_DIR)), save_best_only=True, mode='max', monitor='val_categorical_accuracy')
]

In [18]:
model.summary()

print(f"Beginning training of model {NAME}")

model.fit(train_dataset, epochs=1, callbacks=callbacks, validation_data=val_dataset)

print("Training finished, starting test evaluation")

result = model.evaluate(val_dataset)
print(result)

model.save('out/last_##name##_##data##.keras'.replace('##name##', NAME).replace('##data##', os.path.basename(DATASET_DIR))")

Beginning training of model VisionTransformer








[1m4347/4347[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m744s[0m 156ms/step - auc_1: 0.8891 - binary_accuracy: 0.8016 - loss: 0.4463 - recall_1: 0.8018 - val_auc_1: 0.6393 - val_binary_accuracy: 0.5886 - val_loss: 0.8699 - val_recall_1: 0.5863
Training finished, starting test evaluation
[1m333/333[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m72s[0m 216ms/step - auc_1: 0.6133 - binary_accuracy: 0.5695 - loss: 0.9254 - recall_1: 0.5678
[0.8560488820075989, 0.5964714884757996, 0.5968468189239502, 0.6484871506690979]


ValueError: Invalid filepath extension for saving. Please add either a `.keras` extension for the native Keras format (recommended) or a `.h5` extension. Use `model.export(filepath)` if you want to export a SavedModel for use with TFLite/TFServing/etc. Received: filepath=out/last_##name##_##data##.keras'.replace('##name##', NAME).replace('##data##', os.path.basename(DATASET_DIR)).

In [None]:
%matplotlib inline
# sample random image from validation dataset
im = val_dataset.take(1)

vis = visualize.attention_map(model, im)
plt.imshow(vis)
plt.show()