In [18]:
import re
import tensorflow as tf
from tensorflow import keras
from keras import layers
from pathlib import Path
from dataclasses import dataclass, field, replace, asdict
from typing import List, Optional
import matplotlib.pyplot as plt

@dataclass
class DataConfig:
    """Settings describing where the images live and how they are batched.

    ``prepare_dataset`` reads every field; ``build_model`` additionally reads
    ``img_size`` to size the network input.
    """

    # Directory passed as the first argument to image_dataset_from_directory.
    data_dir: str = "data"
    # Images are loaded (and the model input is built) as img_size x img_size.
    img_size: int = 224
    # Batch size requested from the dataset loader.
    batch_size: int = 32
    # Value forwarded as ``validation_split`` to both loader calls.
    validation_split: float = 0.2
    # Seed forwarded to both loader calls (training and validation).
    seed: int = 123

@dataclass
class ModelConfig:
    """Backbone selection and fine-tuning depth, as consumed by ``build_model``."""

    # Attribute name looked up on ``keras.applications`` to construct the backbone.
    backbone: str = "EfficientNetB0"
    # True -> weights='imagenet'; False -> weights=None.
    pretrained: bool = True
    # How many of the highest-numbered blocks to unfreeze when
    # ``blocks_to_unfreeze`` is not provided.
    unfreeze_blocks: int = 2
    # Explicit block numbers to unfreeze; when non-empty it is used instead of
    # ``unfreeze_blocks``.
    blocks_to_unfreeze: Optional[List[int]] = None

@dataclass
class TrainingConfig:
    """Optimisation settings read by ``compile_and_train``."""

    # Passed to ``model.fit`` as ``epochs``.
    epochs: int = 10
    # Passed to the optimizer constructor as ``learning_rate``.
    learning_rate: float = 1e-4
    # Optimizer name; resolved to a class on ``keras.optimizers``.
    optimizer: str = "adam"
    # Passed to ``model.compile`` as ``loss``.
    loss: str = "binary_crossentropy"
    # Passed to ``model.compile`` as ``metrics``. A factory is used so each
    # instance gets its own list rather than sharing one mutable default.
    metrics: List[str] = field(default_factory=lambda: ["accuracy"])

@dataclass
class CallbackConfig:
    """Output locations for the Keras callbacks created in ``compile_and_train``.

    Each callback is only added when its path is truthy, so an empty string
    switches the corresponding callback off.
    """

    # ``log_dir`` for the TensorBoard callback.
    tensorboard_logdir: str = "logs/"
    # Directory (created if missing) that receives ModelCheckpoint files.
    checkpoint_dir: str = "checkpoints/"

@dataclass
class ExperimentConfig:
    """Top-level configuration grouping the four per-concern sections.

    Every section is built through ``default_factory`` so that each
    ``ExperimentConfig()`` owns fresh section instances.
    """

    # Dataset location, image size, batching and split.
    data: DataConfig = field(default_factory=DataConfig)
    # Backbone choice and which blocks to fine-tune.
    model: ModelConfig = field(default_factory=ModelConfig)
    # Optimizer, loss, metrics and epoch count.
    training: TrainingConfig = field(default_factory=TrainingConfig)
    # TensorBoard / checkpoint output paths.
    callbacks: CallbackConfig = field(default_factory=CallbackConfig)


def print_block(title: str, data: dict):
    """Print ``data`` as an aligned ``key : value`` table framed by ``=`` borders.

    Args:
        title: Heading, centered between the top two border lines.
        data: Mapping of label -> value. Keys are rendered with ``str``.
            Floats are rendered with the general (``g``) format so that small
            magnitudes such as a ``1e-4`` learning rate stay visible instead of
            collapsing to ``0.00``; everything else is rendered with ``str``.
            An empty mapping prints just the framed title.

    Returns:
        None. The table is written to standard output.
    """
    def _render(value) -> str:
        # bool is not a float subclass, so True/False fall through to str().
        if isinstance(value, float):
            return f"{value:g}"
        return str(value)

    labels = {str(k): _render(v) for k, v in data.items()}
    # default=0 keeps an empty mapping from raising inside max().
    lbl_w = max((len(k) for k in labels), default=0)
    val_w = max((len(v) for v in labels.values()), default=0)
    total_w = lbl_w + 2 + val_w + 2
    title_str = f" {title} "
    # The border is at least as wide as the title so centering never truncates.
    border = "=" * max(len(title_str), total_w)
    print(f"\n{border}")
    print(title_str.center(len(border)))
    print(border)
    for k, v in labels.items():
        print(f"{k.ljust(lbl_w)} : {v.rjust(val_w)}")
    print(border + "\n")


def summarize_model(base):
    """Print a per-block table of layer counts and trainability for ``base``.

    Layers are grouped by the number captured from names of the form
    ``block<digits>[letter]_...``. Layers whose names do not match are pooled
    under ``stem/head``. A block is reported as trainable when at least one of
    its layers has a truthy ``trainable`` attribute.

    Args:
        base: Object exposing ``layers``, an iterable of objects with ``name``
            and ``trainable`` attributes.

    Returns:
        None. The table is written to standard output.
    """
    block_pattern = re.compile(r"^block(\d+)[a-z]?_")
    block_layers = {}
    block_trainable = {}
    for layer in base.layers:
        m = block_pattern.match(layer.name)
        # Block number 0 is the bucket for everything outside a numbered block.
        blk = int(m.group(1)) if m else 0
        block_layers[blk] = block_layers.get(blk, 0) + 1
        block_trainable[blk] = block_trainable.get(blk, False) or layer.trainable

    rows = []
    for blk in sorted(block_layers):
        name = f"block{blk}" if blk > 0 else "stem/head"
        rows.append((name, str(block_layers[blk]), 'Yes' if block_trainable[blk] else 'No'))

    headers = ('Block', 'Layers', 'Trainable')
    # Column widths account for the header row as well as the data rows.
    col1, col2, col3 = (max(len(r[i]) for r in rows + [headers]) for i in range(3))
    # 6 = two " | " separators of three characters each.
    border = '=' * (col1 + col2 + col3 + 6)
    print(f"\n{border}")
    print(f"{headers[0].ljust(col1)} | {headers[1].rjust(col2)} | {headers[2].rjust(col3)}")
    print(border)
    # ``flag`` rather than ``tf`` so the tensorflow alias is not shadowed here.
    for name, cnt, flag in rows:
        print(f"{name.ljust(col1)} | {cnt.rjust(col2)} | {flag.rjust(col3)}")
    print(border + "\n")


def prepare_dataset(cfg: ExperimentConfig):
    """Load the training and validation image datasets described by ``cfg.data``.

    Both subsets are read from the same directory with identical loader
    settings (``labels='inferred'``, ``label_mode='binary'``, same batch size,
    image size, split fraction and seed); only ``subset`` differs.

    Args:
        cfg: Experiment configuration; only the ``data`` section is read.

    Returns:
        ``(train_ds, val_ds)``, each with ``prefetch(tf.data.AUTOTUNE)`` applied.
    """
    data_cfg = cfg.data
    load = tf.keras.preprocessing.image_dataset_from_directory
    # Everything except ``subset`` is shared between the two loader calls.
    shared = dict(
        labels='inferred',
        label_mode='binary',
        batch_size=data_cfg.batch_size,
        image_size=(data_cfg.img_size, data_cfg.img_size),
        validation_split=data_cfg.validation_split,
        seed=data_cfg.seed,
    )
    # Training is loaded first, then validation, as two separate calls.
    training, validation = [
        load(data_cfg.data_dir, subset=which, **shared)
        for which in ('training', 'validation')
    ]
    autotune = tf.data.AUTOTUNE
    return training.prefetch(autotune), validation.prefetch(autotune)


def build_model(cfg: ExperimentConfig):
    """Build a single-output sigmoid classifier on a partially unfrozen backbone.

    The backbone named by ``cfg.model.backbone`` is created from
    ``keras.applications`` without its top. Layers are assigned to blocks by
    names of the form ``block<digits>[letter]_...``. A layer ends up trainable
    only if it belongs to a target block and is not a ``BatchNormalization``
    layer; every other backbone layer is frozen.

    Target blocks are chosen as follows:
      * ``cfg.model.blocks_to_unfreeze`` when it is a non-empty list;
      * otherwise the ``cfg.model.unfreeze_blocks`` highest-numbered blocks;
      * no blocks at all when ``unfreeze_blocks`` is zero or negative.

    Args:
        cfg: Experiment configuration; reads ``cfg.model`` and
            ``cfg.data.img_size``.

    Returns:
        ``(model, base)``: the full classifier and the backbone it wraps.
    """
    mc = cfg.model
    img_size = cfg.data.img_size
    block_pattern = re.compile(r"^block(\d+)[a-z]?_")

    def block_of(layer):
        # Block number encoded in the layer name, or None for other layers.
        m = block_pattern.match(layer.name)
        return int(m.group(1)) if m else None

    base = getattr(keras.applications, mc.backbone)(
        include_top=False,
        weights='imagenet' if mc.pretrained else None,
        input_shape=(img_size, img_size, 3)
    )

    unique_blocks = sorted({blk for blk in map(block_of, base.layers) if blk is not None})
    if mc.blocks_to_unfreeze:
        target_blocks = set(mc.blocks_to_unfreeze)
    elif mc.unfreeze_blocks > 0:
        target_blocks = set(unique_blocks[-mc.unfreeze_blocks:])
    else:
        # ``seq[-0:]`` is ``seq[0:]`` (the whole list), so a request for zero
        # blocks must be handled explicitly or it would unfreeze everything.
        target_blocks = set()

    for layer in base.layers:
        layer.trainable = (
            block_of(layer) in target_blocks
            and not isinstance(layer, layers.BatchNormalization)
        )

    inp = keras.Input((img_size, img_size, 3))
    # The backbone is always called with training=False, even while fitting.
    x = base(inp, training=False)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(0.2)(x)
    out = layers.Dense(1, activation='sigmoid')(x)
    model = keras.Model(inp, out)
    return model, base


def compile_and_train(model: keras.Model, train_ds, val_ds, cfg: ExperimentConfig):
    """Compile ``model`` from ``cfg.training`` and fit it with the configured callbacks.

    The optimizer name is matched case-insensitively against the classes
    exposed by ``keras.optimizers``, so names whose class is not simply
    "first letter capitalised" (e.g. ``'sgd'`` -> ``SGD``,
    ``'rmsprop'`` -> ``RMSprop``) resolve correctly. If nothing matches, the
    lookup falls back to ``name.capitalize()`` and therefore raises
    ``AttributeError`` for an unknown optimizer.

    Args:
        model: Model to compile and train.
        train_ds: Training data passed to ``model.fit``.
        val_ds: Validation data passed to ``model.fit``.
        cfg: Experiment configuration; reads ``cfg.training`` and
            ``cfg.callbacks``.

    Returns:
        Whatever ``model.fit`` returns.

    Raises:
        AttributeError: If ``cfg.training.optimizer`` names no optimizer class.
    """
    tc = cfg.training
    cb = cfg.callbacks

    wanted = tc.optimizer.lower()
    # Names are compared first so getattr only touches the matching attribute;
    # the isinstance check skips same-named submodules or functions.
    opt_name = next(
        (
            attr for attr in dir(keras.optimizers)
            if attr.lower() == wanted
            and isinstance(getattr(keras.optimizers, attr), type)
        ),
        tc.optimizer.capitalize(),
    )
    opt = getattr(keras.optimizers, opt_name)(learning_rate=tc.learning_rate)
    model.compile(optimizer=opt, loss=tc.loss, metrics=tc.metrics)

    callbacks = []
    # A falsy path (e.g. '') disables the corresponding callback.
    if cb.tensorboard_logdir:
        callbacks.append(keras.callbacks.TensorBoard(log_dir=cb.tensorboard_logdir))
    if cb.checkpoint_dir:
        p = Path(cb.checkpoint_dir)
        p.mkdir(parents=True, exist_ok=True)
        callbacks.append(
            keras.callbacks.ModelCheckpoint(
                filepath=str(p/'ckpt_{epoch}.keras'), save_best_only=True,
                monitor='val_loss')
        )
    return model.fit(train_ds, validation_data=val_ds, epochs=tc.epochs, callbacks=callbacks)


def plot_history(history):
    """Draw training-vs-validation curves from a fit history.

    One figure is produced for loss. A second figure is produced for accuracy
    only when ``'accuracy'`` is present in ``history.history``. The x axis is
    the 1-based epoch index derived from the length of the loss series.

    Args:
        history: Object whose ``history`` attribute maps metric names to
            per-epoch value sequences (including the ``val_``-prefixed keys).
    """
    hist = history.history
    epochs = range(1, len(hist['loss']) + 1)

    # (training key, validation key, training label, validation label, heading)
    panels = (
        ('loss', 'val_loss', 'Training Loss', 'Validation Loss', 'Loss'),
        ('accuracy', 'val_accuracy', 'Training Accuracy', 'Validation Accuracy', 'Accuracy'),
    )
    for train_key, val_key, train_label, val_label, heading in panels:
        # 'loss' is always present at this point; only accuracy is optional.
        if train_key not in hist:
            continue
        plt.figure()
        plt.plot(epochs, hist[train_key], label=train_label)
        plt.plot(epochs, hist[val_key], label=val_label)
        plt.title(heading)
        plt.xlabel('Epoch')
        plt.ylabel(heading)
        plt.legend()
        plt.show()


def run_experiment(**overrides):
    """Run one experiment end to end and return its training history.

    Starts from a default ``ExperimentConfig``, applies the overrides, prints
    the resulting configuration, loads the data, builds and summarises the
    model, trains it and plots the curves.

    Args:
        **overrides: One keyword per configuration section (the field names of
            ``ExperimentConfig``), each a dict of ``field -> value`` applied to
            that section with ``dataclasses.replace``.

    Returns:
        The object returned by ``compile_and_train`` (the result of
        ``model.fit``), so callers can store and compare runs.

    Raises:
        ValueError: If a keyword is not a configuration section or its value
            is not a dict.
        TypeError: Raised by ``dataclasses.replace`` when a dict names a field
            the section does not have.
    """
    cfg = ExperimentConfig()
    # Only real dataclass fields count as sections; a plain hasattr() check
    # would also accept attributes such as ``__class__``.
    valid_sections = set(asdict(cfg))
    for section, params in overrides.items():
        if section not in valid_sections or not isinstance(params, dict):
            raise ValueError(f"Unknown section '{section}' or invalid params")
        updated = replace(getattr(cfg, section), **params)
        cfg = replace(cfg, **{section: updated})

    # Flatten to "section.field" keys for display.
    cfg_dict = {
        f"{sec}.{k}": v
        for sec, sec_vals in asdict(cfg).items()
        for k, v in sec_vals.items()
    }
    print_block("Experiment Config", cfg_dict)

    train_ds, val_ds = prepare_dataset(cfg)
    model, base = build_model(cfg)
    summarize_model(base)
    history = compile_and_train(model, train_ds, val_ds, cfg)
    plot_history(history)
    # Returned so that ``exp[name] = run_experiment(...)`` stores a usable result.
    return history

In [16]:
# Collects the value returned by run_experiment for each run, keyed by experiment name.
exp = {}

In [None]:
# SAMPLE EXPERIMENT: shows the call pattern only; do not treat its output as a real experiment.
# Only training.metrics is overridden; every other setting keeps its ExperimentConfig default.
exp['SAMPLE'] = run_experiment(
    training={
        'metrics': ['auc', 'accuracy', 'precision', 'recall']
    }
)