In [1]:
# ========= ViT tối giản cho phân loại 5 lớp =========
import os, math, tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# ---------------- Config ----------------
SEED = 42
tf.random.set_seed(SEED)

TRAIN_DIR = "/home/duc/Documents/DoAn/eyepacs_2015/train_preprocess_ben_graham"
VAL_DIR   = "/home/duc/Documents/DoAn/eyepacs_2015/val_preprocess_ben_graham"

NUM_CLASSES  = 5
IMG_SIZE     = 224          # Bạn có thể đổi (224, 256, 384...), miễn chia hết cho PATCH_SIZE
PATCH_SIZE   = 16           # 14x14 patch khi IMG_SIZE=224
EMBED_DIM    = 256          # kích thước ẩn của token
NUM_HEADS    = 8            # số head attention
MLP_DIM      = 512          # hidden size của MLP trong Transformer block
DEPTH        = 6            # số block Transformer
DROPOUT      = 0.1
BATCH_SIZE   = 32
EPOCHS       = 20

AUTOTUNE = tf.data.AUTOTUNE

assert IMG_SIZE % PATCH_SIZE == 0, "IMG_SIZE phải chia hết cho PATCH_SIZE"
NUM_PATCHES = (IMG_SIZE // PATCH_SIZE) * (IMG_SIZE // PATCH_SIZE)

# ---------------- Dataset ----------------
def make_ds(data_dir, subset="train"):
    shuffle = (subset == "train")
    ds = keras.utils.image_dataset_from_directory(
        data_dir,
        labels="inferred",
        label_mode="int",
        image_size=(IMG_SIZE, IMG_SIZE),
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        seed=SEED
    )
    return ds.prefetch(AUTOTUNE)

ds_train = make_ds(TRAIN_DIR, "train")
ds_val   = make_ds(VAL_DIR, "val")

# ---------------- ViT blocks ----------------
def transformer_encoder(x, embed_dim, num_heads, mlp_dim, dropout):
    # Norm -> MHA -> residual
    h = layers.LayerNormalization(epsilon=1e-6)(x)
    h = layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=embed_dim // num_heads, dropout=dropout
    )(h, h)
    h = layers.Dropout(dropout)(h)
    x = layers.Add()([x, h])

    # Norm -> MLP -> residual
    h = layers.LayerNormalization(epsilon=1e-6)(x)
    h = layers.Dense(mlp_dim, activation=keras.activations.gelu)(h)
    h = layers.Dropout(dropout)(h)
    h = layers.Dense(embed_dim)(h)
    h = layers.Dropout(dropout)(h)
    x = layers.Add()([x, h])
    return x

def build_vit_model(
    image_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    embed_dim=EMBED_DIM,
    depth=DEPTH,
    num_heads=NUM_HEADS,
    mlp_dim=MLP_DIM,
    dropout=DROPOUT,
    num_classes=NUM_CLASSES
):
    inputs = keras.Input(shape=(image_size, image_size, 3))

    # (Tuỳ chọn) augmentation rất nhẹ nhàng
    aug = keras.Sequential([
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(0.05),
    ], name="augmentation")

    x = aug(inputs)

    # Chuẩn hoá về [0,1]
    x = layers.Rescaling(1./255)(x)

    # Patch embedding bằng Conv2D (kernel=stride=patch_size) -> (H/ps, W/ps, embed_dim)
    x = layers.Conv2D(
        filters=embed_dim, kernel_size=patch_size, strides=patch_size,
        padding="valid", name="patch_embedding"
    )(x)

    # Flatten thành chuỗi token: (batch, num_patches, embed_dim)
    x = layers.Reshape((-1, embed_dim))(x)  # -1 = (H/ps * W/ps) = NUM_PATCHES

    # Thêm positional embedding học được
    positions = tf.range(start=0, limit=NUM_PATCHES, delta=1)
    pos_embed = layers.Embedding(input_dim=NUM_PATCHES, output_dim=embed_dim, name="pos_embedding")(positions)
    x = x + pos_embed  # broadcasting trên batch

    # Các Transformer encoder blocks
    for i in range(depth):
        x = transformer_encoder(x, embed_dim, num_heads, mlp_dim, dropout)

    # Layer norm cuối + Pooling trung bình token
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    x = layers.GlobalAveragePooling1D()(x)

    # Head phân loại
    x = layers.Dropout(dropout)(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    model = keras.Model(inputs=inputs, outputs=outputs, name="ViT_simple")
    return model

model = build_vit_model()
model.summary()

# ---------------- Compile ----------------
LR = 3e-4
try:
    optimizer = keras.optimizers.AdamW(learning_rate=LR, weight_decay=1e-4)
except Exception:
    # fallback nếu AdamW không có trong phiên bản TF/Keras của bạn
    optimizer = keras.optimizers.Adam(learning_rate=LR)

model.compile(
    optimizer=optimizer,
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=["accuracy"]
)

# ---------------- Train ----------------
os.makedirs("outputs_vit", exist_ok=True)
ckpt_path = "outputs_vit/best_vit.keras"

callbacks = [
    keras.callbacks.ModelCheckpoint(ckpt_path, monitor="val_accuracy", save_best_only=True),
    keras.callbacks.EarlyStopping(monitor="val_accuracy", patience=5, restore_best_weights=True),
    keras.callbacks.ReduceLROnPlateau(monitor="val_accuracy", factor=0.5, patience=2, min_lr=1e-6, verbose=1),
]

history = model.fit(
    ds_train,
    validation_data=ds_val,
    epochs=EPOCHS,
    callbacks=callbacks
)

# Lưu model cuối cùng (tùy chọn)
model.save("outputs_vit/final_vit.keras")
print("Saved:", ckpt_path)


2025-08-19 08:38:15.205846: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-08-19 08:38:15.216280: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1755607095.229106    8163 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1755607095.233014    8163 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1755607095.242126    8163 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

Found 105145 files belonging to 5 classes.


I0000 00:00:1755607100.759963    8163 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 2281 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3050 Laptop GPU, pci bus id: 0000:01:00.0, compute capability: 8.6


Found 3511 files belonging to 5 classes.


Epoch 1/20


I0000 00:00:1755607114.565572    8403 cuda_dnn.cc:529] Loaded cuDNN version 90300


[1m3286/3286[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m549s[0m 163ms/step - accuracy: 0.2694 - loss: 1.5893 - val_accuracy: 0.2848 - val_loss: 1.7508 - learning_rate: 3.0000e-04
Epoch 2/20
[1m3286/3286[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m538s[0m 164ms/step - accuracy: 0.2927 - loss: 1.5173 - val_accuracy: 0.2828 - val_loss: 2.7140 - learning_rate: 3.0000e-04
Epoch 3/20
[1m3286/3286[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m539s[0m 164ms/step - accuracy: 0.3530 - loss: 1.3936 - val_accuracy: 0.3597 - val_loss: 2.6264 - learning_rate: 3.0000e-04
Epoch 4/20
[1m3286/3286[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m539s[0m 164ms/step - accuracy: 0.3805 - loss: 1.3444 - val_accuracy: 0.3426 - val_loss: 2.8372 - learning_rate: 3.0000e-04
Epoch 5/20
[1m3286/3286[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 162ms/step - accuracy: 0.3828 - loss: 1.3512
Epoch 5: ReduceLROnPlateau reducing learning rate to 0.0001500000071246177.
[1m3286/3286[0m [32m━━━━

KeyboardInterrupt: 