In [15]:
import gc
import os

import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.metrics import (ConfusionMatrixDisplay, confusion_matrix,
                             precision_score, recall_score, roc_auc_score)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras import backend as K
from tensorflow.keras import mixed_precision
from tensorflow.keras.layers import (BatchNormalization, Concatenate, Conv2D,
                                     Dense, Dropout, GlobalAveragePooling2D,
                                     Input, MaxPooling2D)
from tensorflow.keras.optimizers import AdamW
from tensorflow.keras.utils import Sequence

In [None]:
# === Configuration ===
config = {
    "data_path": "Dataset/preprocessed_images",  # images live in <data_path>/<label>/
    "csv_path": "processed_data/cleaned_metadata.csv",
    "target_size": (480, 320),  # (width, height), 3:2 ratio
    "epochs": 1,
    "initial_lr": 0.001,
    "gpu_memory_limit": 10,  # GB; not enforced as a device memory cap
    "batch_size": 100,  # batch size used by the data generators
}

# === GPU Setup ===
# These settings must be applied before TensorFlow initializes the GPU; afterwards
# (e.g. when this cell is re-run) TensorFlow raises RuntimeError, which is printed.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        # Enable mixed precision (float16 compute) only when a GPU is present.
        policy = mixed_precision.Policy('mixed_float16')
        mixed_precision.set_global_policy(policy)

        # Allocate GPU memory on demand rather than reserving it all up front.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        tf.config.optimizer.set_jit(True)  # XLA JIT compilation
        tf.config.threading.set_intra_op_parallelism_threads(8)
        tf.config.threading.set_inter_op_parallelism_threads(4)
    except RuntimeError as e:
        print(e)

# === Model Architecture ===
def create_gpu_optimized_model(input_shape, num_classes):
    """Build a compact CNN with a plain and a dilated parallel 3x3 branch.

    Args:
        input_shape: per-sample input shape as (height, width, channels).
        num_classes: number of output classes.

    Returns:
        An uncompiled ``tf.keras.Model`` whose output is a softmax over
        ``num_classes`` classes.
    """
    # Use the default (float32) input dtype. Under the 'mixed_float16' policy the
    # first Conv2D casts to float16 itself; without that policy (no GPU) a
    # hard-coded float16 input would only discard precision.
    inputs = Input(shape=input_shape)

    # Initial feature extraction: stride-2 conv then stride-2 pooling.
    x = Conv2D(96, (7, 7), strides=2, activation='relu', padding='same')(inputs)
    x = BatchNormalization()(x)
    x = MaxPooling2D((3, 3), strides=2)(x)

    # Intermediate layers
    x = Conv2D(256, (5, 5), activation='relu', padding='same')(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((3, 3), strides=2)(x)

    # Parallel paths: same kernel size, the dilated one sees a wider context.
    branch1 = Conv2D(384, (3, 3), activation='relu', padding='same')(x)
    branch2 = Conv2D(384, (3, 3), dilation_rate=2, activation='relu', padding='same')(x)
    x = Concatenate()([branch1, branch2])

    # Final classification head; 'features' is looked up by name elsewhere.
    x = GlobalAveragePooling2D()(x)
    x = Dense(512, activation='relu', name='features')(x)
    x = Dropout(0.5)(x)
    # Keep the softmax output in float32 so it stays numerically stable under
    # mixed precision.
    outputs = Dense(num_classes, activation='softmax', dtype=tf.float32)(x)

    return tf.keras.Model(inputs=inputs, outputs=outputs)

# === Memory Management ===
def calculate_max_batch_size(model, input_shape=(480, 320, 3), gpu_mem=24):
    """Rough, conservative estimate of the largest batch that fits in GPU memory.

    This is a heuristic, not a measurement:
      * model weights cost ``count_params() * 4`` bytes and are charged ONCE as
        a fixed overhead, because they do not grow with the batch;
      * each sample is charged ``prod(input_shape) * units * 4`` bytes, where
        ``units`` is the width of the Dense layer named ``'features'``;
      * 3 GB of ``gpu_mem`` is held back as a safety buffer.

    Args:
        model: Keras model containing a Dense layer named ``'features'``.
        input_shape: shape of one sample; only its element count is used.
        gpu_mem: memory budget in GB.

    Returns:
        The estimate clamped to [16, 256]. Returns 16 when the budget is already
        exhausted and 32 when no estimate can be made.
    """
    gib = 1024 ** 3
    try:
        features = next(
            (layer for layer in reversed(model.layers)
             if isinstance(layer, Dense) and layer.name == 'features'),
            None,
        )
        if features is None:
            print("Error calculating batch size: no Dense layer named 'features'")
            return 32  # Fallback value

        fixed_gb = model.count_params() * 4 / gib  # weights: paid once, not per sample
        per_sample_gb = float(np.prod(input_shape)) * features.units * 4 / gib
        if per_sample_gb <= 0:
            print("Error calculating batch size: per-sample estimate is zero")
            return 32  # Fallback value

        budget_gb = gpu_mem - 3 - fixed_gb  # keep a 3 GB buffer
        max_batch = int(budget_gb / per_sample_gb)
        return max(16, min(256, max_batch))

    except Exception as e:
        print(f"Error calculating batch size: {e}")
        return 32  # Fallback value

def cleanup_gpu_memory():
    """Reset Keras global state and collect Python garbage.

    Calls ``K.clear_session()`` and then ``gc.collect()``, so that tensors held
    only by discarded models can be freed. TensorFlow's GPU allocator generally
    keeps memory it has already reserved, so device-level usage (e.g. in
    nvidia-smi) may not drop.

    Memory growth is configured once in the GPU setup cell. It is not re-applied
    here: after the GPU is initialized, TensorFlow rejects that call with
    RuntimeError.
    """
    K.clear_session()
    # Release Python references to dropped models/tensors so buffers can be reused.
    gc.collect()

# === Data Pipeline ===
def load_and_preprocess_data(random_state=42, save_splits=True, test_size=0.15):
    """Load the metadata CSV, encode labels and make a stratified train/val split.

    Args:
        random_state: seed passed to ``train_test_split`` for a reproducible split.
        save_splits: if True, write the label-encoder classes
            ('training6_label_encoder.npy') and both splits
            ('training6_train_set.csv', 'training6_validation_set.csv').
            If False, nothing is written, so re-deriving the split (e.g. during
            evaluation) cannot overwrite the artifacts produced by training.
        test_size: fraction of rows held out for validation.

    Returns:
        ``(train_df, val_df, le)``. Both frames carry a ``label_encoded`` column,
        and ``le`` is the fitted ``LabelEncoder``.
    """
    df = pd.read_csv(config["csv_path"])

    le = LabelEncoder()
    df['label_encoded'] = le.fit_transform(df['label'])

    # Stratify on the raw label so class proportions match in both splits.
    train_df, val_df = train_test_split(
        df,
        test_size=test_size,
        stratify=df['label'],
        random_state=random_state,
    )

    if save_splits:
        # Class order maps the model's output indices back to label names.
        np.save('training6_label_encoder.npy', le.classes_)
        train_df.to_csv('training6_train_set.csv', index=False)
        val_df.to_csv('training6_validation_set.csv', index=False)

    return train_df, val_df, le

class RiceDataGenerator(Sequence):
    """Keras ``Sequence`` that yields ``(images, labels)`` batches from a metadata frame.

    Each row of ``df`` must provide ``image_id``, ``label`` and ``label_encoded``.
    A row's image is read from ``<base_path>/<label>/<stem>_green.jpg``, passed
    through ``self.aug`` (which ends by resizing to ``target_size``) and scaled
    to [0, 1]. Images that fail to load are reported and dropped from their
    batch, so a returned batch can be smaller than ``batch_size``.

    Args:
        df: metadata DataFrame; a re-indexed copy is stored.
        base_path: root directory containing one sub-directory per label.
        batch_size: number of rows per batch.
        target_size: output image size as ``(width, height)``.
        shuffle: shuffle row order at construction and after every epoch.
        debug: plot the first samples before/after augmentation on construction.
        **kwargs: forwarded to ``Sequence.__init__``.
    """

    def __init__(self, df, base_path, batch_size=32, target_size=(480, 320), shuffle=True, debug=False, **kwargs):
        super().__init__(**kwargs)
        self.df = df.reset_index(drop=True)
        self.base_path = base_path
        self.batch_size = batch_size
        self.target_size = target_size  # (width, height)
        self.shuffle = shuffle
        self.debug = debug
        self.indices = np.arange(len(self.df))

        # The final Resize is the single place images are brought to target_size.
        # Keep it last: some of the disabled transforms below change the shape.
        self.aug = A.Compose([
            # A.RandomRotate90(),
            # A.HorizontalFlip(),
            # A.VerticalFlip(),
            # A.Transpose(),
            # A.RandomBrightnessContrast(p=0.5),
            # A.HueSaturationValue(p=0.5),
            # A.CLAHE(p=0.5),
            A.Resize(width=self.target_size[0], height=self.target_size[1]),
        ])

        if shuffle:
            np.random.shuffle(self.indices)

        if self.debug:
            self._visualize_samples()

    def _visualize_samples(self):
        """Show the raw image next to its augmented version for the first 2 rows."""
        for i in range(min(2, len(self.df))):
            row = self.df.iloc[i]  # bound before try so the error message can use it
            try:
                img = self._load_image(row['image_id'], row['label'], resize=False)
                augmented = self.aug(image=img)['image']

                fig, (ax_orig, ax_aug) = plt.subplots(1, 2, figsize=(12, 6))
                ax_orig.imshow(img)
                ax_orig.set_title(f"Original\nShape: {img.shape}")
                ax_aug.imshow(augmented)
                ax_aug.set_title(f"Augmented\nShape: {augmented.shape}")
                fig.tight_layout()
                plt.show()
                plt.close(fig)

            except Exception as e:
                print(f"Visualization failed for {row['image_id']}: {str(e)}")

    def __len__(self):
        """Number of batches per epoch (the last one may be partial)."""
        return int(np.ceil(len(self.df) / self.batch_size))

    @property
    def num_batches(self):
        return len(self)

    def _load_image(self, image_id, label, suffix='green', resize=True):
        """Read ``<base_path>/<label>/<stem>_<suffix>.jpg`` as an RGB array.

        Args:
            image_id: file name; its extension is replaced by ``_<suffix>.jpg``.
            label: class name, used as the sub-directory.
            suffix: variant suffix appended to the file stem.
            resize: resize to ``target_size`` here. Internal callers pass False
                because ``self.aug`` already ends with a Resize to the same size.

        Raises:
            FileNotFoundError: if ``cv2.imread`` cannot read the file.
        """
        img_path = os.path.join(
            self.base_path,
            label,
            f"{os.path.splitext(image_id)[0]}_{suffix}.jpg"
        )
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(f"Image not found at {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
        if resize:
            # cv2.resize takes (width, height)
            img = cv2.resize(img, self.target_size)
        return img

    def __getitem__(self, idx):
        """Return ``(X, y)`` for batch ``idx``, with failed loads removed."""
        batch_indices = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_df = self.df.iloc[batch_indices]

        # Arrays are (batch, height, width, 3); target_size is (width, height).
        width, height = self.target_size
        X = np.zeros((len(batch_df), height, width, 3), dtype=np.float32)
        y = np.full((len(batch_df),), -1, dtype=np.int32)  # -1 marks failed loads

        for i, row in enumerate(batch_df.itertuples(index=False)):
            try:
                img = self._load_image(row.image_id, row.label, resize=False)
                X[i] = self.aug(image=img)['image'] / 255.0
                y[i] = row.label_encoded
            except Exception as e:
                print(f"Error loading {row.image_id}: {str(e)}")

        valid = y != -1
        return X[valid], y[valid]

    def on_epoch_end(self):
        """Shuffle indices after each epoch"""
        if self.shuffle:
            np.random.shuffle(self.indices)

# === Training ===
def train():
    """Train the CNN on the stratified split and checkpoint the best epoch.

    The batch size is ``config["batch_size"]`` when set. Otherwise it falls back
    to ``calculate_max_batch_size``. The printed value is the one actually
    used by the generators.

    Returns:
        ``(model, history)`` from ``model.fit``. EarlyStopping and the checkpoint
        both track ``val_accuracy``, so the restored in-memory weights and
        'training6_best_model.h5' refer to the same best epoch.

    Raises:
        Re-raises any exception after printing it and resetting Keras state.
    """
    cleanup_gpu_memory()

    try:
        # Load data with fixed random state
        train_df, val_df, le = load_and_preprocess_data(random_state=42)
        num_classes = len(le.classes_)

        # Model input is (height, width, channels); target_size is (width, height).
        input_shape = (config["target_size"][1], config["target_size"][0], 3)
        model = create_gpu_optimized_model(input_shape, num_classes)

        # An explicit config value wins; the heuristic is only a fallback.
        batch_size = config.get("batch_size") or calculate_max_batch_size(
            model, input_shape=input_shape, gpu_mem=config["gpu_memory_limit"]
        )

        print("\n=== Training Configuration ===")
        print(f"Batch size: {batch_size}")
        print(f"Input size: {config['target_size']}")
        print(f"Classes: {num_classes}")
        print(f"GPU Memory: {config['gpu_memory_limit']}GB\n")
        print(f"Model input shape: {model.input_shape}")

        # Create generators
        train_gen = RiceDataGenerator(
            df=train_df,
            base_path=config["data_path"],
            batch_size=batch_size,
            target_size=config["target_size"],
            shuffle=True
        )

        val_gen = RiceDataGenerator(
            df=val_df,
            base_path=config["data_path"],
            batch_size=batch_size,
            target_size=config["target_size"],
            shuffle=False
        )

        # Compile model
        model.compile(
            optimizer=AdamW(learning_rate=config["initial_lr"], weight_decay=1e-4),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )

        # Train
        history = model.fit(
            train_gen,
            validation_data=val_gen,
            epochs=config["epochs"],
            callbacks=[
                tf.keras.callbacks.EarlyStopping(
                    monitor='val_accuracy',
                    mode='max',
                    patience=10,
                    restore_best_weights=True
                ),
                tf.keras.callbacks.ModelCheckpoint(
                    'training6_best_model.h5',
                    save_best_only=True,
                    monitor='val_accuracy',
                    mode='max'
                ),
                tf.keras.callbacks.CSVLogger('training6_history.csv')
            ]
        )

        return model, history

    except Exception as e:
        print(f"Training failed: {e}")
        cleanup_gpu_memory()
        raise

In [None]:
def _collect_predictions(model, generator):
    """Run ``model`` over every batch of ``generator``; return aligned (y_true, y_prob)."""
    y_true, y_prob = [], []
    for i in range(len(generator)):
        x, y = generator[i]
        if len(y) == 0:  # every image in this batch failed to load
            continue
        y_true.append(y)
        y_prob.append(model.predict(x, verbose=0))
    if not y_true:
        raise RuntimeError("No evaluation images could be loaded")
    return np.concatenate(y_true), np.concatenate(y_prob)


def _print_sklearn_metrics(y_true, y_pred, y_prob, num_classes):
    """Print macro precision/recall and one-vs-rest ROC AUC (when it is defined)."""
    precision = precision_score(y_true, y_pred, average='macro', zero_division=0)
    recall = recall_score(y_true, y_pred, average='macro', zero_division=0)
    print(f"{'precision':15}: {precision:.4f}")
    print(f"{'recall':15}: {recall:.4f}")
    try:
        if num_classes == 2:
            # Binary AUC expects the positive-class score as a 1-D array.
            auc = roc_auc_score(y_true, y_prob[:, 1])
        else:
            auc = roc_auc_score(y_true, y_prob, multi_class='ovr',
                                labels=np.arange(num_classes))
        print(f"{'auc':15}: {auc:.4f}")
    except ValueError as e:
        # e.g. a class missing from y_true leaves one-vs-rest AUC undefined.
        print(f"{'auc':15}: n/a ({e})")


def evaluate_saved_model(model_path, use_val_set=True):
    """Evaluate a saved model: Keras loss/accuracy/top-k plus sklearn metrics.

    Args:
        model_path: path of a saved model, e.g. 'training6_best_model.h5'.
        use_val_set: if True, read the validation split saved during training
            ('training6_validation_set.csv'). Otherwise re-derive the split from
            the metadata CSV with ``load_and_preprocess_data(save_splits=False)``.

    Returns:
        The list returned by ``model.evaluate``: loss first, then the compiled
        metrics in ``model.metrics_names`` order.

    Raises:
        Re-raises any exception after printing it and resetting Keras state.
    """
    try:
        # Class order saved at training time maps output indices to names.
        with open('training6_label_encoder.npy', 'rb') as f:
            classes = np.load(f, allow_pickle=True)
        num_classes = len(classes)

        # Load model with custom objects
        custom_objects = {'AdamW': tf.keras.optimizers.AdamW}
        with tf.keras.utils.custom_object_scope(custom_objects):
            model = tf.keras.models.load_model(model_path)

        # Load appropriate dataset
        if use_val_set:
            print("\nUsing saved validation set for evaluation")
            eval_df = pd.read_csv('training6_validation_set.csv')
        else:
            print("\nRe-deriving the validation split from the metadata CSV")
            _, eval_df, _ = load_and_preprocess_data(save_splits=False)

        # Create evaluation generator
        eval_gen = RiceDataGenerator(
            df=eval_df,
            base_path=config["data_path"],
            batch_size=config["batch_size"],
            target_size=config["target_size"],
            shuffle=False
        )

        # tf.keras has no Sparse{AUC,Precision,Recall} metrics (referencing them
        # raised AttributeError), so precision/recall/AUC come from sklearn
        # below. Clamp k so top-k is meaningful with fewer than 3 classes.
        model.compile(
            optimizer=model.optimizer if model.optimizer is not None else 'adam',
            loss='sparse_categorical_crossentropy',
            metrics=[
                'accuracy',
                tf.keras.metrics.SparseTopKCategoricalAccuracy(
                    k=min(3, num_classes), name='top3_accuracy'),
            ]
        )

        # Evaluate
        print("\n=== Evaluating Model ===")
        results = model.evaluate(eval_gen, verbose=1)

        # Display results
        print("\n=== Evaluation Results ===")
        for name, value in zip(model.metrics_names, results):
            print(f"{name:15}: {value:.4f}")

        # Generate predictions
        print("\nGenerating predictions...")
        y_true, y_prob = _collect_predictions(model, eval_gen)
        y_pred = y_prob.argmax(axis=1)
        _print_sklearn_metrics(y_true, y_pred, y_prob, num_classes)

        plot_confusion_matrix(y_true, y_pred, classes)

        return results

    except Exception as e:
        print(f"Evaluation failed: {e}")
        cleanup_gpu_memory()
        raise
    
def plot_confusion_matrix(y_true, y_pred, classes):
    """Plot a confusion matrix with one row/column per entry of ``classes``.

    Args:
        y_true: integer-encoded true labels.
        y_pred: integer-encoded predicted labels.
        classes: class names; index ``i`` names encoded label ``i``.
    """
    # Pin the label set so the matrix is always len(classes) x len(classes).
    # Without it, classes absent from both y_true and y_pred shrink the matrix,
    # and plotting with display_labels=classes can then fail on the size mismatch.
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(classes)))
    fig, ax = plt.subplots(figsize=(12, 10))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=classes)
    disp.plot(include_values=True, ax=ax, cmap='viridis',
              xticks_rotation='vertical', values_format='d')
    ax.set_title('Confusion Matrix')
    fig.tight_layout()
    plt.show()

In [None]:
# === Execution Options ===
if __name__ == "__main__":
    # Option 1: train, then evaluate the best checkpoint on the saved validation split.
    model, history = train()
    evaluate_saved_model('training6_best_model.h5')

    # Option 2: only evaluate a previously trained model.
    # evaluate_saved_model takes (model_path, use_val_set=True). Pass
    # use_val_set=False to re-derive the split from config["csv_path"] instead
    # of reading 'training6_validation_set.csv'.
    # evaluate_saved_model('training6_best_model.h5', use_val_set=False)