In [None]:
import tensorflow as tf
from tensorflow.keras import mixed_precision
from tensorflow.keras.layers import (Conv2D, MaxPooling2D, Dense, Dropout, 
                                   BatchNormalization, Input, GlobalAveragePooling2D, 
                                   Concatenate, LeakyReLU, Reshape, Conv2DTranspose, 
                                   Flatten, UpSampling2D)
from tensorflow.keras.utils import Sequence
from tensorflow.keras.models import Sequential
from tensorflow.keras import backend as K
import numpy as np
import os
import cv2
import random
import albumentations as A
from sklearn.preprocessing import LabelEncoder
import pandas as pd
from tensorflow.keras.optimizers import AdamW
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
import seaborn as sns

# Introduce randomness
random.seed(42)
np.random.seed(42)
tf.random.set_seed(42)

os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices=false'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  

# === Configuration ===
config = {
    "epochs": 1,
    "batch_size": 100,
    "initial_lr": 0.001,
    "gpu_memory_limit": 10,  # in GB
    "target_size": (320, 480),  # 2:3 ratio (width, height)
    "input_shape": (480, 320, 3), # (height, width, channels) for Keras
    "data_path": "Dataset/train_images",
    "csv_path": "processed_data/cleaned_metadata_short.csv", # _short is for testing only
    "train_set_csv": "training8_rgb/training8_rgb_train_set.csv",
    "val_set_csv": "training8_rgb/training8_rgb_validation_set.csv",
    "history_csv": "training8_rgb/training8_rgb_history.csv",
    "best_model": "training8_rgb/training8_rgb_best_model.keras",
    "label_encoder_path": "training8_rgb/training8_rgb_label_encoder.npy",
    "gan_epochs": 20,
    "latent_dim": 100,
    "gan_batch_size": 32,
    "synthetic_ratio": 0.3 
}

# === GPU Setup ===
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        policy = mixed_precision.Policy('float32')
        mixed_precision.set_global_policy(policy)
        
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        tf.config.optimizer.set_jit(True)
        tf.config.threading.set_intra_op_parallelism_threads(8)
        tf.config.threading.set_inter_op_parallelism_threads(4)
    except RuntimeError as e:
        print(e)
        
# === Memory Management ===
def calculate_max_batch_size(model, input_shape, gpu_mem=24, default_batch=32, is_use_config_batch_size=False):
    """Improved batch size calculator with error handling"""
    if is_use_config_batch_size:
        return default_batch
    try:
        params = model.count_params()
        
        last_dense = None
        for layer in reversed(model.layers):
            if isinstance(layer, tf.keras.layers.Dense):
                last_dense = layer
                if layer.name == 'features':  
                    break
        
        if last_dense is None:
            raise ValueError("No Dense layer found in model!")
        
        # Memory per sample = weights + activations (in GB)
        per_sample = (
            (params * 4) +                 
            (np.prod(input_shape) * last_dense.units * 4)  
        ) / (1024 ** 3)
        
        # Max batch size with 3GB safety margin
        max_batch = int((gpu_mem - 3) / per_sample)
        return min(256, max_batch)  
    
    except Exception as e:
        print(f"Warning: Batch size estimation failed, using default={default_batch}. Error: {e}")
        return default_batch

def cleanup_gpu_memory():
    """Force clear GPU memory"""
    K.clear_session()
    tf.compat.v1.reset_default_graph()
    if tf.config.list_physical_devices('GPU'):
        try:
            for gpu in tf.config.list_physical_devices('GPU'):
                tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError:
            pass


In [13]:
# === Data Pipeline ===
def load_and_preprocess_data(random_state=42, save_splits=True):
    """Load and split data with fixed random state for reproducibility"""
    df = pd.read_csv(config["csv_path"])
    
    le = LabelEncoder()
    df['label_encoded'] = le.fit_transform(df['label'])
    print(f"Label classes: {le.classes_}")
    
    with open(config['label_encoder_path'], 'wb') as f:
        np.save(f, le.classes_)
    
    train_df, val_df = train_test_split(
        df, 
        test_size=0.2, 
        stratify=df['label'],
        random_state=random_state,
    )
    
    if save_splits:
        train_df.to_csv(config['train_set_csv'], index=False)
        val_df.to_csv(config['val_set_csv'], index=False)
    
    return train_df, val_df, le

def build_gan(generator_input_shape=(100,), image_shape=(480, 320, 3)):
    # Generator
    generator = tf.keras.Sequential([
        Dense(256 * 120 * 80, use_bias=False, input_shape=generator_input_shape),
        BatchNormalization(),
        LeakyReLU(0.2),
        Reshape((120, 80, 256)),
        
        Conv2DTranspose(128, (5,5), strides=1, padding='same', use_bias=False),
        BatchNormalization(),
        LeakyReLU(0.2),
        
        Conv2DTranspose(64, (5,5), strides=2, padding='same', use_bias=False),
        BatchNormalization(),
        LeakyReLU(0.2),
        
        Conv2DTranspose(3, (5,5), strides=2, padding='same', 
                       activation='tanh', use_bias=False)
    ], name='generator')

    # Discriminator
    discriminator = tf.keras.Sequential([
        Conv2D(64, (5,5), strides=2, padding='same', 
               input_shape=image_shape),
        LeakyReLU(0.2),
        Dropout(0.3),
        
        Conv2D(128, (5,5), strides=2, padding='same'),
        LeakyReLU(0.2),
        Dropout(0.3),
        
        Flatten(),
        Dense(1, activation='sigmoid')
    ], name='discriminator')

    return generator, discriminator

# === Model Architecture ===
def create_gpu_optimized_model(input_shape, num_classes):    
    inputs = Input(shape=input_shape, dtype=tf.float32) 
     
    # Initial feature extraction
    x = Conv2D(96, (7,7), strides=2, activation='relu', padding='same')(inputs)
    x = BatchNormalization()(x)
    x = MaxPooling2D((3,3), strides=2)(x)
    
    # Intermediate layers
    x = Conv2D(256, (5,5), activation='relu', padding='same')(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((3,3), strides=2)(x)
    
    # Parallel paths
    branch1 = Conv2D(384, (3,3), activation='relu', padding='same')(x)
    branch2 = Conv2D(384, (3,3), dilation_rate=2, activation='relu', padding='same')(x)
    x = Concatenate()([branch1, branch2])
    
    # Final classification head
    x = GlobalAveragePooling2D()(x)
    x = Dense(512, activation='relu', name='features')(x)
    x = Dropout(0.5)(x)
    outputs = Dense(num_classes, activation='softmax', dtype=tf.float32)(x)
    
    return tf.keras.Model(inputs=inputs, outputs=outputs)

class RiceDataGenerator(Sequence):
    def __init__(self, df, base_path, batch_size=32, target_size={config['target_size']}, shuffle=False, debug=False, **kwargs):
        super().__init__(**kwargs)
        self.df = df.reset_index(drop=True)
        self.base_path = base_path
        self.batch_size = batch_size
        self.target_size = target_size  
        self.shuffle = shuffle
        self.debug = debug
        self.indices = np.arange(len(df))
        
        self.aug = A.Compose([
            # A.RandomRotate90(),
            # A.HorizontalFlip(),
            # A.VerticalFlip(),
            # A.Transpose(),
            # A.RandomBrightnessContrast(p=0.5),
            # A.HueSaturationValue(p=0.5),
            # A.CLAHE(p=0.5),
            A.Resize(width=self.target_size[0], height=self.target_size[1]),
        ])
        
        if shuffle:
            np.random.shuffle(self.indices)
            
        if self.debug:
            self._visualize_samples()    
            

    def _visualize_samples(self):
        """Visualize first 2 samples after augmentation"""
        
        for i in range(min(1, len(self.df))):
            try:
                row = self.df.iloc[i]
                img = self._load_image(row['image_id'], row['label'])
                augmented = self.aug(image=img)
                
                plt.figure(figsize=(12, 6))
                
                # original
                plt.subplot(1, 2, 1)
                plt.imshow(img)
                plt.title(f"Original\nShape: {img.shape}")
                
                # augmented
                plt.subplot(1, 2, 2)
                plt.imshow(augmented['image'])
                plt.title(f"Augmented\nShape: {augmented['image'].shape}")
                
                plt.tight_layout()
                plt.show()
                
            except Exception as e:
                print(f"Visualization failed for {row['image_id']}: {str(e)}")
    
    def __len__(self):
        return int(np.ceil(len(self.df) / self.batch_size))
    
    def _load_image(self, image_id, label, suffix=''): #_nipy_spectral
        img_path = os.path.join(
            self.base_path,
            label,
            f"{os.path.splitext(image_id)[0]}{suffix}.jpg"
        )
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f"Image not found at {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, self.target_size)  
        return img
    
    def __getitem__(self, idx):
        batch_indices = self.indices[idx*self.batch_size:(idx+1)*self.batch_size]
        batch_df = self.df.iloc[batch_indices]
        
        X = np.zeros((len(batch_df), self.target_size[1], self.target_size[0], 3), dtype=np.float32) #(batch, height, width, channels)
        y = np.zeros((len(batch_df),), dtype=np.int32)
        
        for i, (_, row) in enumerate(batch_df.iterrows()):
            try:
                img = self._load_image(row['image_id'], row['label'])
                augmented = self.aug(image=img)
                X[i] = augmented['image'] / 255.0
                y[i] = row['label_encoded']
            except Exception as e:
                print(f"Error loading {row['image_id']}: {str(e)}")
                X[i] = np.zeros((self.target_size[1], self.target_size[0], 3), dtype=np.float32) #(batch, height, width, channels)
                y[i] = -1
                
        valid = y != -1
        return X[valid], y[valid]
    
class AugmentedRiceGenerator(RiceDataGenerator):
    def __init__(self, gan_generator, **kwargs):
        super().__init__(**kwargs)
        self.gan = gan_generator
        self.latent_dim = 100
        
    def __getitem__(self, idx):
        X_real, y_real = super().__getitem__(idx)
        
        # Generate synthetic samples
        noise = np.random.normal(0, 1, (len(X_real), self.latent_dim))
        X_fake = self.gan.predict(noise, verbose=0)
        X_fake = (X_fake + 1) / 2  # Scale from [-1,1] to [0,1]
        
        # Blend real and synthetic data (50-50 mix)
        X_mixed = np.concatenate([X_real, X_fake])
        y_mixed = np.concatenate([y_real, y_real])  # Use same labels
        
        return X_mixed, y_mixed

def train_gan(train_gen, epochs=50, latent_dim=100):
    generator, discriminator = build_gan()
    
    # Combined model
    discriminator.compile(
        optimizer=AdamW(learning_rate=0.0002, beta_1=0.5),
        loss='binary_crossentropy',
        metrics=['accuracy']
    )
    
    discriminator.trainable = False
    gan_input = Input(shape=(latent_dim,))
    gan_output = discriminator(generator(gan_input))
    gan = tf.keras.Model(gan_input, gan_output)
    gan.compile(
        optimizer=AdamW(learning_rate=0.0002, beta_1=0.5),
        loss='binary_crossentropy'
    )
    
    # Training
    for epoch in range(epochs):
        for batch_idx, (real_images, _) in enumerate(train_gen):
            # Train discriminator
            noise = np.random.normal(0, 1, (real_images.shape[0], latent_dim))
            fake_images = generator.predict(noise, verbose=0)
            
            d_loss_real = discriminator.train_on_batch(real_images, np.ones((real_images.shape[0], 1)))
            d_loss_fake = discriminator.train_on_batch(fake_images, np.zeros((fake_images.shape[0], 1)))
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
            
            # Train generator
            noise = np.random.normal(0, 1, (real_images.shape[0], latent_dim))
            g_loss = gan.train_on_batch(noise, np.ones((real_images.shape[0], 1)))
            
        print(f"Epoch {epoch+1}, D Loss: {d_loss[0]:.4f}, G Loss: {g_loss:.4f}")
    
    return generator

In [14]:
# === Training ===
def train():
    cleanup_gpu_memory()
    
    try:
        # === 1. Data Loading ===
        train_df, val_df, le = load_and_preprocess_data(random_state=42)
        num_classes = len(le.classes_)
        print("Classes: ", num_classes)
        
        # === 2. Model Initialization ===
        input_shape = config["input_shape"]
        model = create_gpu_optimized_model(input_shape, num_classes)
        
        # === 3. Memory Optimization ===
        optimized_batch_size = calculate_max_batch_size(
            model, 
            input_shape=config["input_shape"],
            gpu_mem=config["gpu_memory_limit"],
            default_batch=config["batch_size"],
            is_use_config_batch_size=False
        )
        
        print(f"\n=== Training Configuration ===")
        print(f"Batch size: {optimized_batch_size}")
        print(f"Input size: {config['target_size']}")
        print(f"Classes: {num_classes}")
        print(f"GPU Memory: {config['gpu_memory_limit']}GB\n")
        
        # === 4. GAN Training ===
        print("=== Phase 1: Training GAN for Augmentation ===")
        base_train_gen = RiceDataGenerator(
            df=train_df,
            base_path=config["data_path"],
            batch_size=min(32, optimized_batch_size),  # Smaller batches for GAN
            target_size=config["target_size"],
            shuffle=True
        )
        
        gan_generator = train_gan(base_train_gen, epochs=20)
        
        # === 5. Augmented Data Pipeline ===
        print("=== Phase 2: Training CNN with GAN Augmentation ===")
        augmented_train_gen = AugmentedRiceGenerator(
            gan_generator=gan_generator,
            df=train_df,
            base_path=config["data_path"],
            batch_size=optimized_batch_size,
            target_size=config["target_size"],
            shuffle=True
        )
        
        val_gen = RiceDataGenerator(
            df=val_df,
            base_path=config["data_path"],
            batch_size=optimized_batch_size,
            target_size=config["target_size"],
            shuffle=False
        )
        
        # === 6. Model Compilation ===
        model.compile(
            optimizer=AdamW(learning_rate=config["initial_lr"]),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy',
                   tf.keras.metrics.SparseTopKCategoricalAccuracy(k=3, name='top3_acc')]
        )
        
        # === 7. Training Loop ===
        history = model.fit(
            augmented_train_gen,
            validation_data=val_gen,
            epochs=config["epochs"],
            callbacks=[
                tf.keras.callbacks.EarlyStopping(
                    monitor='val_accuracy',
                    patience=10,  # Increased patience for GAN-augmented training
                    mode='max',
                    restore_best_weights=True
                ),
                tf.keras.callbacks.ModelCheckpoint(
                    config["best_model"],
                    monitor='val_accuracy',
                    save_best_only=True,
                    save_weights_only=False
                ),
                tf.keras.callbacks.ReduceLROnPlateau(
                    monitor='val_accuracy',
                    factor=0.5,
                    patience=3,
                    verbose=1,
                    mode='max'
                ),
                tf.keras.callbacks.CSVLogger(config["history_csv"])
            ]
        )
        
        # === 8. Visualization ===
        plt.figure(figsize=(12, 5))
        plt.subplot(1, 2, 1)
        plt.plot(history.history['accuracy'], label='Train Accuracy')
        plt.plot(history.history['val_accuracy'], label='Val Accuracy')
        plt.axhline(y=max(history.history['val_accuracy']), color='r', linestyle='--')
        plt.title('Accuracy Curves')
        plt.legend()
        
        plt.subplot(1, 2, 2)
        plt.plot(history.history['loss'], label='Train Loss')
        plt.plot(history.history['val_loss'], label='Val Loss')
        plt.title('Loss Curves')
        plt.legend()
        
        plt.tight_layout()
        plt.savefig('training_metrics.png')
        
        return model, history
        
    except Exception as e:
        print(f"Training failed: {e}")
        cleanup_gpu_memory()
        raise

In [15]:
def evaluate_saved_model(model_path, use_val_set=True):
    """Deterministic evaluation with debug checks"""
    try:
        # Load model 
        with open(config['label_encoder_path'], 'rb') as f:
            classes = np.load(f, allow_pickle=True)
        le = LabelEncoder()
        le.classes_ = classes

        model = tf.keras.models.load_model(model_path, compile=False) 
            
        cleanup_gpu_memory()
        optimized_batch_size = calculate_max_batch_size(
                                    model, 
                                    input_shape=config["input_shape"],
                                    gpu_mem=config["gpu_memory_limit"],
                                    default_batch=config["batch_size"],
                                    is_use_config_batch_size=False
                                )
        
        model.compile(
            optimizer='adam',
            loss='sparse_categorical_crossentropy',
            metrics=[
                tf.keras.metrics.SparseCategoricalAccuracy(name='acc'),
                tf.keras.metrics.SparseTopKCategoricalAccuracy(k=3, name='top3_acc')
            ]
        )

        eval_df = pd.read_csv(config["val_set_csv"]) if use_val_set else load_and_preprocess_data(save_splits=False)[1]
        print(f"Evaluating on {len(eval_df)} samples")

        eval_gen = RiceDataGenerator(
            df=eval_df,
            base_path=config["data_path"],
            batch_size=optimized_batch_size,  
            target_size=config["target_size"],
            shuffle=False,  
            debug=True
        )

        # Inspect first batch
        x_test, y_test = eval_gen[0]
        print(f"\n[DEBUG] First batch - X shape: {x_test.shape}, y shape: {y_test.shape}")
        print(f"[DEBUG] Sample label: {y_test[0]} -> {le.classes_[y_test[0]]}")

        # Standard evaluation
        results = model.evaluate(eval_gen, verbose=1, return_dict=True)
        print("\n[METRICS] Evaluation Results:", results)

        # Collect all predictions
        y_true, y_pred = [], []
        for i in range(len(eval_gen)):
            x, y = eval_gen[i]
            y_true.extend(y)
            y_pred.extend(model.predict(x, verbose=0).argmax(axis=1))  # Get class indices

        y_true = np.array(y_true)
        y_pred = np.array(y_pred)
        assert len(y_true) == len(y_pred), "Label/prediction length mismatch!"

        # Print sample predictions
        print("\n[PREDICTION SAMPLES]")
        for i in range(min(5, len(y_true))):
            print(f"True: {le.classes_[y_true[i]]} ({y_true[i]}) | Pred: {le.classes_[y_pred[i]]} ({y_pred[i]})")

        # Confusion Matrix Analysis
        print("\n[CONFUSION MATRIX PARAMETERS]")
        print(f"- Classes: {le.classes_}")  # Class names from LabelEncoder
        print(f"- Total samples: {len(y_true)}")
        print(f"- Batch size: {optimized_batch_size} (affects matrix granularity)")
        print(f"- Most confused classes: Will be visible in plot")

        # Plot both normalized and raw counts
        plot_confusion_matrix(
            y_true, 
            y_pred, 
            classes=le.classes_,
            normalize=True,
            title='Normalized Confusion Matrix (%)'
        )

        plot_confusion_matrix(
            y_true,
            y_pred,
            classes=le.classes_,
            normalize=False,
            title='Confusion Matrix (Counts)'
        )

        # Additional metrics
        from sklearn.metrics import classification_report
        print("\n[CLASSIFICATION REPORT]")
        print(classification_report(y_true, y_pred, target_names=le.classes_))

        return results

    except Exception as e:
        print(f"Evaluation failed: {e}")
        raise
    
def plot_confusion_matrix(y_true, y_pred, classes, normalize=False, title=None):
    """
    Plots the confusion matrix.
    
    Args:
        y_true: True labels
        y_pred: Predicted labels
        classes: List of class names
        normalize: Whether to normalize the matrix
        title: Plot title
    """
    cm = confusion_matrix(y_true, y_pred)
    
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        fmt = '.2f'
    else:
        fmt = 'd'
    
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt=fmt, xticklabels=classes, yticklabels=classes,
                cmap='Blues', cbar=False)
    
    plt.title(title or 'Confusion Matrix')
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()


In [None]:
# === Execution Options ===
if __name__ == "__main__":
    model, model_history = train()
    # model.save(config["best_model"])
    evaluate_saved_model(config["best_model"], use_val_set=True)

Label classes: ['bacterial_leaf_blight' 'bacterial_leaf_streak' 'blast' 'normal' 'tungro']
Classes:  5

=== Training Configuration ===
Batch size: -2
Input size: (320, 480)
Classes: 5
GPU Memory: 1GB

=== Phase 1: Training GAN for Augmentation ===


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
E0000 00:00:1746958408.763736  753466 meta_optimizer.cc:967] remapper failed: INVALID_ARGUMENT: Mutation::Apply error: fanout 'gradient_tape/functional_3_1/discriminator_1/leaky_re_lu_3_1/LeakyRelu/LeakyReluGrad' exist for missing node 'functional_3_1/discriminator_1/conv2d_4_1/BiasAdd'.
