In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import (
    Input, Conv2D, DepthwiseConv2D, BatchNormalization, ReLU,
    GlobalAveragePooling2D, Dense, Add, Reshape, LayerNormalization,
    MultiHeadAttention, Flatten
)
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

# Parameters
input_shape = (128, 128, 3)  # (height, width, channels) of the model's Input layer
num_classes = 14             # softmax outputs; should equal the number of class sub-folders in the data dirs
batch_size = 32
epochs = 20
# Data locations default to the original hard-coded paths. Set TRAIN_DIR / VAL_DIR
# in the environment to point the notebook at other directories without editing code.
train_dir = os.environ.get("TRAIN_DIR", "c:/MyData/train")
val_dir = os.environ.get("VAL_DIR", "c:/MyData/val")

# Depthwise Separable Conv Block + GAP Skip
def dsc_block_with_skip(x_input, filters):
    """Apply a depthwise-separable conv block plus a global-context skip path.

    Main path: 3x3 depthwise conv -> BatchNorm -> ReLU -> 1x1 pointwise conv
    to ``filters`` channels -> BatchNorm -> ReLU. Both convs use
    ``padding='same'`` and the default stride of 1, so spatial size is kept.

    Skip path: global-average-pool the block input, project the pooled vector
    to ``filters`` channels with a ReLU Dense layer, and add it to every
    spatial position of the main-path output.

    Args:
        x_input: 4-D Keras tensor; presumably channels-last (batch, H, W, C)
            -- TODO confirm against the data format in use.
        filters: Number of output channels of the block.

    Returns:
        Keras tensor with the spatial size of ``x_input`` and ``filters``
        channels.
    """
    # Main path
    x = DepthwiseConv2D(kernel_size=3, padding='same', use_bias=False)(x_input)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = Conv2D(filters, kernel_size=1, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    # GAP skip path: (batch, filters) -> (batch, 1, 1, filters). Add() broadcasts
    # the singleton spatial dims across H x W, so no tf.tile / tf.shape is needed.
    # Staying with Keras layers keeps the functional graph valid under Keras 3,
    # where raw tf.* ops cannot be applied to symbolic KerasTensors, and keeps
    # the output's static spatial shape.
    gap = GlobalAveragePooling2D()(x_input)
    gap = Dense(filters, activation='relu')(gap)
    gap = Reshape((1, 1, filters))(gap)

    # Residual add
    return Add()([x, gap])

# Patch Embedding for ViT
def patch_embedding(x, patch_size=8, embed_dim=256):
    """Cut a feature map into non-overlapping patches and embed each as a token.

    A Conv2D whose kernel size and stride both equal ``patch_size`` (with
    ``'valid'`` padding) maps every ``patch_size x patch_size`` window to an
    ``embed_dim`` vector. The resulting patch grid is then flattened into a
    sequence of ``(H // patch_size) * (W // patch_size)`` tokens.

    Args:
        x: 4-D Keras tensor; presumably channels-last (batch, H, W, C) -- TODO confirm.
        patch_size: Side length of each square patch, in pixels.
        embed_dim: Size of each token embedding.

    Returns:
        Keras tensor of shape (batch, num_tokens, embed_dim).

    NOTE(review): no positional embedding is added here, so the following
    attention layers see the tokens as an unordered set -- confirm this is intended.
    """
    projector = Conv2D(embed_dim, patch_size, strides=patch_size, padding='valid')
    patch_grid = projector(x)
    return Reshape((-1, embed_dim))(patch_grid)

# Vision Transformer block
def vit_block(x, num_heads=2):
    """Apply one pre-norm transformer encoder layer to a token sequence.

    The layer has two sub-blocks, each normalised before use and wrapped in a
    residual connection:
      1. LayerNorm -> multi-head self-attention -> add back to the input.
      2. LayerNorm -> single ReLU Dense layer of the same width -> add back.

    Args:
        x: 3-D Keras tensor; presumably (batch, tokens, dim) -- TODO confirm.
        num_heads: Number of attention heads.

    Returns:
        Keras tensor with the same shape as ``x``.

    NOTE(review): ``key_dim`` is set to the full token width, so each head
    projects to ``dim`` features (total ``num_heads * dim``); a common ViT
    choice is ``dim // num_heads``. Confirm this is intended.
    """
    width = x.shape[-1]

    # Self-attention sub-block (query and value are the same normalised tokens).
    normed = LayerNormalization()(x)
    attended = MultiHeadAttention(num_heads=num_heads, key_dim=width)(normed, normed)
    x = Add()([x, attended])

    # Feed-forward sub-block: one Dense layer with no hidden-width expansion.
    normed = LayerNormalization()(x)
    projected = Dense(width, activation='relu')(normed)
    return Add()([x, projected])

# Full model
def build_model():
    """Assemble the DSC + ViT image classifier as a Keras functional Model.

    Reads the module-level ``input_shape`` and ``num_classes``. The pipeline is:
    five depthwise-separable stages (16 -> 256 channels), then patch embedding,
    then one transformer block, then Flatten -> Dense(1024, ReLU) -> softmax
    over ``num_classes``.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping images to class probabilities.

    NOTE(review): no stage uses a stride > 1, so the 256-channel stage runs at
    full input resolution. With the default 8x8 patches on a 128x128 input, the
    head also flattens 256 tokens x 256 dims into Dense(1024), about 67M
    weights. This is memory-heavy; confirm it is intended.
    """
    inputs = Input(shape=input_shape)

    # Widen the representation stage by stage; layers are created in the same
    # order as five explicit calls would create them.
    features = inputs
    for width in (16, 32, 64, 128, 256):
        features = dsc_block_with_skip(features, width)

    # ViT
    tokens = patch_embedding(features)
    encoded = vit_block(tokens)

    # Classifier head
    head = Flatten()(encoded)
    head = Dense(1024, activation='relu')(head)
    outputs = Dense(num_classes, activation='softmax')(head)
    return Model(inputs, outputs)

# Instantiate and compile model
# Categorical cross-entropy pairs with the one-hot labels produced by
# class_mode='categorical' in the data generators.
model = build_model()
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
model.summary()

# Data generators
# Training images get light augmentation (zoom + horizontal flip); validation
# images are only scaled by 1/255.
train_gen = ImageDataGenerator(rescale=1./255, zoom_range=0.2, horizontal_flip=True)
val_gen = ImageDataGenerator(rescale=1./255)

# Resize to the model's spatial input size. It is derived from input_shape so the
# loader and the Input layer cannot drift apart.
image_size = input_shape[:2]
train_data = train_gen.flow_from_directory(train_dir, target_size=image_size, batch_size=batch_size, class_mode='categorical')
val_data = val_gen.flow_from_directory(val_dir, target_size=image_size, batch_size=batch_size, class_mode='categorical')

# Labels are one-hot over the discovered class folders, and the head has
# num_classes softmax outputs. Fail fast with a readable message instead of a
# shape error inside model.fit, or of silently misaligned validation labels.
if train_data.num_classes != num_classes:
    raise ValueError(
        f"Found {train_data.num_classes} class folders in {train_dir!r}, "
        f"but the model was built with num_classes={num_classes}."
    )
if val_data.class_indices != train_data.class_indices:
    raise ValueError(
        f"Class folders differ between {train_dir!r} and {val_dir!r}: "
        f"{sorted(train_data.class_indices)} vs {sorted(val_data.class_indices)}."
    )

# Training
# The validation iterator is monitored each epoch; per-epoch metrics end up in
# history.history for the plots below.
history = model.fit(
    train_data,
    validation_data=val_data,
    epochs=epochs,
)

# Evaluation
# evaluate() returns [loss, accuracy] in the order given to compile(metrics=...).
eval_results = model.evaluate(val_data)
loss, acc = eval_results
print(f"\n Validation Accuracy: {acc*100:.2f}%, Loss: {loss:.4f}")

# Plot accuracy and loss
# One panel per metric; each panel shows the training and validation curves.
curve_fig, panels = plt.subplots(1, 2, figsize=(12, 5))
curve_specs = [
    (panels[0], 'accuracy', 'val_accuracy', "Train Acc", "Val Acc", "Accuracy"),
    (panels[1], 'loss', 'val_loss', "Train Loss", "Val Loss", "Loss"),
]
for panel, train_key, val_key, train_label, val_label, metric_name in curve_specs:
    panel.plot(history.history[train_key], label=train_label)
    panel.plot(history.history[val_key], label=val_label)
    panel.set_title(metric_name)
    panel.set_xlabel("Epoch")
    panel.set_ylabel(metric_name)
    panel.legend()
plt.tight_layout()
plt.show()

# Confusion matrix + classification report
# flow_from_directory shuffles by default, so predicting on val_data would
# return rows in a shuffled order. That order does not match val_data.classes,
# which follows the file list. Predict on a dedicated, non-shuffled iterator over
# the same directory, with the same settings and class mapping, so every
# prediction lines up with its ground-truth label.
eval_data = val_gen.flow_from_directory(
    val_dir,
    target_size=val_data.target_size,
    batch_size=val_data.batch_size,
    class_mode=val_data.class_mode,
    classes=list(val_data.class_indices.keys()),
    shuffle=False,
)
val_labels = eval_data.classes
class_names = list(eval_data.class_indices.keys())
predictions = model.predict(eval_data)
predicted_classes = np.argmax(predictions, axis=1)

# Pin the label set so every class gets a row/column even when it never occurs
# in val_labels or predicted_classes. Without this, sklearn rejects target_names
# whose length differs from the number of observed labels.
label_ids = np.arange(len(class_names))

print("\n Classification Report:")
print(classification_report(val_labels, predicted_classes, labels=label_ids, target_names=class_names))

cm = confusion_matrix(val_labels, predicted_classes, labels=label_ids)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
disp.plot(xticks_rotation=90, cmap='Blues')
plt.title("Confusion Matrix")
plt.tight_layout()
plt.show()
