In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
from sklearn.utils.class_weight import compute_class_weight

# Side length in pixels of the square model input; used both as the
# generators' target_size and for the model's build shape.
IMAGE_SIZE = 224
# Number of images per batch yielded by the data generators.
BATCH_SIZE = 32
# Upper bound on training epochs passed to model.fit (early stopping may end sooner).
EPOCHS = 40
# Initial learning rate for the Adam optimizer.
LEARNING_RATE = 1e-4
# Root directory handed to flow_from_directory for both the training and validation subsets.
dataset_dir = "/home/anjalit/Crop_Disease_Detection/Crop__Disease"
# Directory under which a per-epoch SavedModel is written during training.
checkpoint_dir = "/home/anjalit/Crop_Disease_Detection/Crop__Disease/checkpoints"

# Report how many GPU devices TensorFlow can see before training starts.
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

def get_data_generators(dataset_dir, image_size, batch_size, seed=42):
    """Create the training and validation directory iterators.

    Both iterators read from the same ``dataset_dir`` and use the same
    ``validation_split`` of 0.2; the ``subset`` argument selects which side
    of the split each one serves. Only the training iterator applies random
    augmentation and shuffling.

    Args:
        dataset_dir: Directory passed to ``flow_from_directory``.
        image_size: Side length images are resized to (square target).
        batch_size: Number of images per yielded batch.
        seed: Seed forwarded to both iterators.

    Returns:
        Tuple ``(train_generator, val_generator)``.
    """
    # Settings common to both ImageDataGenerator instances.
    pixel_scale = 1.0 / 255
    holdout_fraction = 0.2

    # Settings common to both flow_from_directory calls.
    flow_options = {
        "directory": dataset_dir,
        "target_size": (image_size, image_size),
        "batch_size": batch_size,
        "class_mode": "categorical",
        "seed": seed,
    }

    # Training images get geometric and photometric augmentation.
    augmenting_datagen = ImageDataGenerator(
        rescale=pixel_scale,
        validation_split=holdout_fraction,
        rotation_range=30,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        vertical_flip=True,
        fill_mode="nearest",
        brightness_range=[0.7, 1.3],
        channel_shift_range=50.0,
    )

    # Validation images are only rescaled.
    plain_datagen = ImageDataGenerator(
        rescale=pixel_scale,
        validation_split=holdout_fraction,
    )

    train_generator = augmenting_datagen.flow_from_directory(
        subset="training",
        shuffle=True,
        **flow_options,
    )

    # Kept in a fixed order (shuffle=False).
    val_generator = plain_datagen.flow_from_directory(
        subset="validation",
        shuffle=False,
        **flow_options,
    )

    return train_generator, val_generator

def get_class_weights(generator):
    """Compute "balanced" class weights from a generator's label array.

    Args:
        generator: Object exposing a ``classes`` attribute with one class
            label per sample.

    Returns:
        dict mapping each class label present in ``generator.classes`` to
        the weight returned for it by ``compute_class_weight``.
    """
    present_labels = np.unique(generator.classes)
    weights = compute_class_weight(
        class_weight="balanced",
        classes=present_labels,
        y=generator.classes,
    )
    # Key each weight by its actual label, not by its position in
    # `present_labels`: the two only coincide when the labels happen to be
    # exactly 0..k-1, so positional keys would silently shift weights onto
    # the wrong classes if any label is missing from the data.
    # tolist() yields plain Python scalars for the keys and values.
    return dict(zip(present_labels.tolist(), weights.tolist()))

class TransformerBlock(layers.Layer):
    """Self-attention followed by a two-layer MLP.

    Each sub-layer's output is added to its input and the sum is passed
    through a LayerNormalization.
    """

    def __init__(self, d_model, num_heads, mlp_dim):
        """Create the sub-layers.

        Args:
            d_model: Output width of the MLP; also used as the attention
                layer's ``key_dim``.
            num_heads: Number of attention heads.
            mlp_dim: Width of the MLP's hidden (GELU) layer.
        """
        super().__init__()
        self.layernorm1 = layers.LayerNormalization()
        self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
        self.layernorm2 = layers.LayerNormalization()
        expand = layers.Dense(mlp_dim, activation=tf.nn.gelu)
        project = layers.Dense(d_model)
        self.mlp = keras.Sequential([expand, project])

    def call(self, x, training):
        """Apply the block to ``x``.

        ``training`` is accepted but not forwarded to any sub-layer.
        """
        # Self-attention: x serves as query, value and key.
        attended = self.layernorm1(x + self.mha(x, x, x))
        return self.layernorm2(attended + self.mlp(attended))

class MultiScaleVisionTransformer(tf.keras.Model):
    """Image classifier that feeds patches of several sizes to one transformer stack.

    For every entry in ``patch_sizes`` a strided convolution turns the image
    into a sequence of ``d_model``-wide patch embeddings. The sequences are
    concatenated along the token axis, passed through ``num_layers``
    ``TransformerBlock`` layers, averaged over the token axis and classified
    by a softmax dense layer.

    NOTE(review): no positional encoding is added to the patch embeddings.
    """

    def __init__(self, image_size=224, patch_sizes=(16, 32), num_layers=8, d_model=256, num_heads=8, mlp_dim=512, num_classes=7):
        """Create the patch projections, transformer blocks and classifier.

        Args:
            image_size: Stored on the instance; not otherwise used here.
            patch_sizes: Iterable of patch side lengths, one projection each.
            num_layers: Number of transformer blocks.
            d_model: Embedding width of every token.
            num_heads: Attention heads per transformer block.
            mlp_dim: Hidden width of each block's MLP.
            num_classes: Number of output classes.
        """
        super().__init__()
        self.image_size = image_size
        # Copy into a fresh list: the default is immutable and a
        # caller-supplied sequence is never aliased, so mutating one model's
        # patch_sizes cannot leak into other instances.
        self.patch_sizes = list(patch_sizes)
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.d_model = d_model
        self.mlp_dim = mlp_dim
        self.num_classes = num_classes

        # kernel_size == strides, so each convolution output position
        # corresponds to one non-overlapping size x size patch.
        self.patch_projections = []
        for size in self.patch_sizes:
            self.patch_projections.append(layers.Conv2D(
                filters=self.d_model,
                kernel_size=size,
                strides=size,
                padding='valid'
            ))

        self.transformer_blocks = [TransformerBlock(self.d_model, self.num_heads, self.mlp_dim) for _ in range(self.num_layers)]
        self.classifier_head = layers.Dense(self.num_classes, activation='softmax')

    def call(self, x, training=False):
        """Return class probabilities for a batch of images ``x``."""
        patch_embeddings = []
        for proj in self.patch_projections:
            patches = proj(x)
            # Flatten the spatial grid into a token axis: (batch, tokens, d_model).
            patches = tf.reshape(patches, [tf.shape(x)[0], -1, self.d_model])
            patch_embeddings.append(patches)

        # Tokens from all patch scales form one sequence.
        x = tf.concat(patch_embeddings, axis=1)
        for block in self.transformer_blocks:
            x = block(x, training=training)

        # Mean over the token axis gives one vector per image.
        x = tf.reduce_mean(x, axis=1)
        return self.classifier_head(x)

def train_model():
    """Build, compile and train the multi-scale vision transformer.

    Reads the module-level settings ``dataset_dir``, ``checkpoint_dir``,
    ``IMAGE_SIZE``, ``BATCH_SIZE``, ``EPOCHS`` and ``LEARNING_RATE``.
    During training the full model is saved under ``checkpoint_dir`` after
    every epoch.

    Returns:
        Tuple ``(model, history, val_generator)``: the trained model, the
        object returned by ``model.fit`` and the validation generator.
    """
    # Create the checkpoint directory up-front so an unusable location
    # fails here rather than after the first epoch has been trained.
    os.makedirs(checkpoint_dir, exist_ok=True)

    strategy = tf.distribute.get_strategy()

    with strategy.scope():
        train_generator, val_generator = get_data_generators(dataset_dir, IMAGE_SIZE, BATCH_SIZE)
        # One output unit per class directory discovered by the generator.
        num_classes = len(train_generator.class_indices)

        model = MultiScaleVisionTransformer(
            image_size=IMAGE_SIZE,
            patch_sizes=[16, 32],
            num_layers=8,
            d_model=256,
            num_heads=8,
            mlp_dim=512,
            num_classes=num_classes
        )
        model.build((None, IMAGE_SIZE, IMAGE_SIZE, 3))

        optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE)
        model.compile(
            optimizer=optimizer,
            loss='categorical_crossentropy',
            metrics=['accuracy', keras.metrics.Precision(), keras.metrics.Recall()]
        )

    class_weights = get_class_weights(train_generator)

    class SaveModelCallback(tf.keras.callbacks.Callback):
        """Save the model being trained at the end of every epoch."""

        def on_epoch_end(self, epoch, logs=None):
            # `epoch` is zero-based; directory names are one-based.
            self.model.save(os.path.join(checkpoint_dir, f"saved_model_epoch_{epoch + 1}"), save_format="tf")

    history = model.fit(
        train_generator,
        validation_data=val_generator,
        epochs=EPOCHS,
        class_weight=class_weights,
        callbacks=[
            keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
            keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=1e-6),
            SaveModelCallback()
        ]
    )

    return model, history, val_generator

# Run the full training pipeline; this executes as soon as the script/cell runs.
model, history, val_generator = train_model()

# Persist the final model to a path relative to the current working directory.
model.save("multi_scale_vit_model.tf")
# Evaluate on the validation generator; the result is stored but not printed here.
evaluation = model.evaluate(val_generator)

def plot_training_history(history):
    """Show training vs. validation accuracy and loss side by side.

    Args:
        history: Object whose ``history`` attribute is a mapping containing
            the keys ``accuracy``, ``val_accuracy``, ``loss`` and
            ``val_loss``.
    """
    # (training key, validation key, panel title), left panel first.
    panels = (
        ('accuracy', 'val_accuracy', 'Model Accuracy'),
        ('loss', 'val_loss', 'Model Loss'),
    )
    _, axes = plt.subplots(1, 2, figsize=(15, 5))
    for axis, (train_key, val_key, title) in zip(axes, panels):
        axis.plot(history.history[train_key], label='Train')
        axis.plot(history.history[val_key], label='Validation')
        axis.set_title(title)
        axis.legend()
    plt.show()

# Visualise the accuracy and loss curves recorded during training.
plot_training_history(history)
