In [None]:
import tensorflow as tf
from tensorflow.keras.callbacks import Callback
import numpy as np

class PeriodicEvaluation(Callback):
    """Keras callback that evaluates the model on a test set every few epochs.

    Every ``eval_every`` epochs (counting epochs from 1) the model is evaluated
    on ``test_gen``; the resulting loss and accuracy are appended to
    ``test_losses`` / ``test_accuracies`` and a train/val/test summary is
    printed.

    Args:
        test_gen: Test data, passed unchanged to ``self.model.evaluate``.
        eval_every: Number of epochs between evaluations; must be >= 1.

    Raises:
        ValueError: If ``eval_every`` is less than 1 (it is used as a modulus).
    """

    def __init__(self, test_gen, eval_every=5):
        super().__init__()
        if eval_every < 1:
            # A zero modulus would otherwise raise ZeroDivisionError mid-training.
            raise ValueError(f"eval_every must be >= 1, got {eval_every!r}")
        self.test_gen = test_gen
        self.eval_every = eval_every
        self.test_accuracies = []
        self.test_losses = []
        # 1-based epoch numbers matching each entry in test_accuracies/test_losses.
        self.eval_epochs = []

    @staticmethod
    def _accuracy_from(metrics):
        """Return the accuracy entry of a metrics dict, or NaN if none exists.

        Prefers the exact key ``"accuracy"`` and otherwise falls back to the
        first key containing ``"acc"`` (e.g. ``"sparse_categorical_accuracy"``).
        """
        if "accuracy" in metrics:
            return metrics["accuracy"]
        return next(
            (value for name, value in metrics.items() if "acc" in name),
            float("nan"),
        )

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate on the test set when ``epoch + 1`` is a multiple of ``eval_every``.

        Args:
            epoch: 0-based epoch index supplied by Keras.
            logs: Metrics for the finished epoch supplied by Keras; may be ``None``.
        """
        if (epoch + 1) % self.eval_every != 0:
            return
        logs = logs or {}  # the signature allows None; avoid AttributeError on .get

        # return_dict=True keeps this working regardless of how many metrics the
        # model was compiled with; positional unpacking into two names breaks
        # when there is no metric (scalar result) or more than one.
        results = self.model.evaluate(self.test_gen, verbose=0, return_dict=True)
        test_loss = results["loss"]
        test_acc = self._accuracy_from(results)

        self.test_accuracies.append(test_acc)
        self.test_losses.append(test_loss)
        self.eval_epochs.append(epoch + 1)

        print(f"\n🎯 Epoch {epoch+1} - Test Accuracy: {test_acc:.4f}, Test Loss: {test_loss:.4f}")

        # Missing metrics become NaN rather than 0 so the printed gap is not fabricated.
        train_acc = logs.get('accuracy', float("nan"))
        val_acc = logs.get('val_accuracy', float("nan"))
        generalization_gap = train_acc - val_acc
        print(f"📊 Generalization Gap (Train-Val): {generalization_gap:.4f}")
        print(f"📊 Train-Val-Test: {train_acc:.4f} | {val_acc:.4f} | {test_acc:.4f}")

# Run a held-out test-set evaluation every 3 epochs, then register it after
# the existing training callbacks.
periodic_eval = PeriodicEvaluation(test_gen=test_gen, eval_every=3)

callbacks = [
    early_stop,
    reduce_lr,
    checkpoint,
    periodic_eval,
]