# In [1]:
# TensorFlow and protobuf read these variables when they are imported, so they
# have to be in place before the library imports further down this file.
import os

os.environ.update(
    {
        # "2" hides TensorFlow's C++ INFO and WARNING log output.
        "TF_CPP_MIN_LOG_LEVEL": "2",
        # Turn off oneDNN optimisations (author's stated reason: avoid memory issues).
        "TF_ENABLE_ONEDNN_OPTS": "0",
        # Use the pure-Python protobuf backend instead of the native extension.
        "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python",
    }
)

print("[STARTUP] Environment variables configured", flush=True)

# Global configuration for the instrument recognition system.
# The script below indexes it directly (CONFIG[...]) and passes it to
# AudioProcessor; MultiResolutionCNN reads "learning_rate" from it.
CONFIG = {
    # Sampling rate (Hz) that librosa.load resamples every file to.
    "sample_rate": 22050,
    # One mel spectrogram, and one CNN input branch, is built per entry.
    "mel_bands": [64, 96, 128],
    # STFT window length and hop length, in samples, for the mel spectrograms.
    "n_fft": 2048,
    "hop_length": 512,
    # Adam learning rate used when compiling MultiResolutionCNN.
    "learning_rate": 0.0001,
    # NOTE(review): nothing in the visible code rewrites this map. Training labels
    # come from the dataset-folder mapping plus MultiLabelBinarizer; only these
    # values are consulted, as a fallback for folders missing from that mapping.
    "class_map": {
        0: "cello",
        1: "clarinet",
        2: "flute",
        3: "acoustic_guitar",
        4: "electric_guitar",
        5: "organ",
        6: "piano",
        7: "saxophone",
        8: "trumpet",
        9: "violin",
        10: "voice",
    },
    # Held-out fraction and seed passed to train_test_split.
    "test_size": 0.2,
    "random_state": 42,
    # At most this many files are kept per instrument (the first ones encountered).
    "max_files_per_instrument": 250,
    # Files within this leading fraction of the processing order also get one
    # augmented copy added to the training data.
    "augmentation_ratio": 0.7,
}

print("[STARTUP] Loading libraries (this may take 30-60 seconds)...", flush=True)
import librosa
# Explicit submodule import: librosa.display.specshow is used for the mel
# spectrogram figure, and librosa releases before 0.10 do not load this
# submodule on `import librosa` (the resulting AttributeError would otherwise be
# swallowed by the plotting try/except, producing empty subplots).
import librosa.display
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, Input, Model
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import signal

# Let TensorFlow grow GPU memory on demand rather than reserving it all up front.
# Any failure is reported and TensorFlow's default memory behaviour is kept.
try:
    gpus = tf.config.list_physical_devices("GPU")
    if not gpus:
        print("[MEMORY] Running on CPU - using optimized memory settings", flush=True)
    else:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"[MEMORY] Configured {len(gpus)} GPU(s) with memory growth", flush=True)
except Exception as exc:
    print(f"[MEMORY] Using default memory configuration: {exc}", flush=True)

print("[STARTUP] All libraries loaded successfully!", flush=True)


class MultiResolutionCNN:
    """Multi-input CNN that fuses one convolutional branch per mel resolution.

    Every entry of ``input_shapes`` gets its own branch of four Conv2D blocks
    followed by global average pooling. The pooled vectors are concatenated and
    passed through a three-layer dense head ending in one sigmoid unit per
    class. The compiled Keras model is exposed as ``self.model``.
    """

    # Per conv block: (filters, apply 2x2 max-pooling, dropout rate or None).
    _CONV_BLOCKS = (
        (64, True, 0.35),
        (128, True, 0.4),
        (256, True, 0.45),
        (256, False, None),  # last block feeds global pooling directly
    )
    # Per dense layer of the shared head: (units, apply batch-norm, dropout rate).
    _DENSE_BLOCKS = (
        (512, True, 0.6),
        (256, True, 0.55),
        (128, False, 0.5),
    )
    _CONV_L2 = 0.0001
    _DENSE_L2 = 0.00015

    def __init__(self, input_shapes, num_classes, learning_rate=None):
        """Build and compile the model.

        Args:
            input_shapes: one shape per input branch, e.g.
                ``[(64, 259, 1), (96, 259, 1), (128, 259, 1)]``.
            num_classes: number of sigmoid output units.
            learning_rate: Adam learning rate; ``None`` falls back to
                ``CONFIG["learning_rate"]`` (0.0002 if that key is absent).

        Raises:
            ValueError: if ``input_shapes`` yields no shapes.
        """
        inputs = []
        processed = []
        for shape in input_shapes:
            inp = Input(shape=shape)
            inputs.append(inp)
            processed.append(self._conv_branch(inp))
        if not processed:
            raise ValueError("input_shapes must contain at least one shape")

        # Concatenate features from all resolutions (a single branch needs no merge).
        x = layers.Concatenate()(processed) if len(processed) > 1 else processed[0]
        x = self._dense_head(x)
        output = layers.Dense(num_classes, activation="sigmoid")(x)
        self.model = Model(inputs=inputs, outputs=output)

        if learning_rate is None:
            learning_rate = CONFIG.get("learning_rate", 0.0002)
        self.model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
            loss="binary_crossentropy",
            metrics=[
                # With binary_crossentropy, the plain "accuracy" string resolves to
                # per-label binary accuracy. For one-hot targets over K classes this
                # is dominated by true negatives (an all-zero prediction already
                # scores (K-1)/K). Top-1 categorical accuracy instead checks that
                # the highest-scoring class is the labelled one. It keeps the name
                # "accuracy" so history keys and evaluate() ordering are unchanged.
                tf.keras.metrics.CategoricalAccuracy(name="accuracy"),
                tf.keras.metrics.AUC(name="auc"),
            ],
        )

    @classmethod
    def _conv_branch(cls, inp):
        """Apply the per-resolution conv stack to ``inp``; return the pooled features."""
        x = inp
        for filters, pool, dropout in cls._CONV_BLOCKS:
            x = layers.Conv2D(
                filters,
                (3, 3),
                activation="relu",
                padding="same",
                kernel_regularizer=tf.keras.regularizers.l2(cls._CONV_L2),
            )(x)
            x = layers.BatchNormalization()(x)
            if pool:
                x = layers.MaxPooling2D((2, 2))(x)
            if dropout is not None:
                x = layers.Dropout(dropout)(x)
        return layers.GlobalAveragePooling2D()(x)

    @classmethod
    def _dense_head(cls, x):
        """Apply the shared fully-connected layers that precede the output layer."""
        for units, batch_norm, dropout in cls._DENSE_BLOCKS:
            x = layers.Dense(
                units,
                activation="relu",
                kernel_regularizer=tf.keras.regularizers.l2(cls._DENSE_L2),
            )(x)
            if batch_norm:
                x = layers.BatchNormalization()(x)
            x = layers.Dropout(dropout)(x)
        return x

    def train(self, X_train, y_train, X_val, y_val, epochs=100, batch_size=12):
        """Fit the model and return the Keras ``History``.

        EarlyStopping (``restore_best_weights=True``), ModelCheckpoint (best model
        written to ``best_model.keras``) and ReduceLROnPlateau all monitor
        ``val_loss`` computed on ``(X_val, y_val)``.
        """
        callbacks = [
            tf.keras.callbacks.EarlyStopping(
                patience=30,
                restore_best_weights=True,
                monitor="val_loss",
                min_delta=0.0001,
            ),
            tf.keras.callbacks.ModelCheckpoint(
                "best_model.keras",  # native .keras format
                save_best_only=True,
                monitor="val_loss",
            ),
            tf.keras.callbacks.ReduceLROnPlateau(
                factor=0.6, patience=10, min_lr=5e-7, monitor="val_loss", verbose=1
            ),
        ]
        return self.model.fit(
            X_train,
            y_train,
            validation_data=(X_val, y_val),
            epochs=epochs,
            batch_size=batch_size,
            callbacks=callbacks,
            verbose=1,
        )

    def evaluate(self, X_test, y_test):
        """Return ``[loss, accuracy, auc]`` on the given data, without progress output."""
        return self.model.evaluate(X_test, y_test, verbose=0)


class AudioProcessor:
    """Load audio, apply random augmentation and extract mel-spectrogram features.

    Settings are read from ``config`` with these defaults: ``sample_rate``
    (22050), ``mel_bands`` ([64, 96, 128]), ``n_fft`` (2048) and
    ``hop_length`` (512).
    """

    def __init__(self, config):
        self.config = config
        self.sample_rate = config.get("sample_rate", 22050)
        self.mel_bands = config.get("mel_bands", [64, 96, 128])
        self.n_fft = config.get("n_fft", 2048)
        self.hop_length = config.get("hop_length", 512)

    def load_audio(self, file_path):
        """Load ``file_path`` as a mono signal resampled to ``self.sample_rate``."""
        audio, _ = librosa.load(file_path, sr=self.sample_rate, mono=True)
        return audio

    def augment_audio(self, audio):
        """Return a randomly augmented copy of ``audio``; the input is not modified.

        At most one of time-stretch (45%) or pitch-shift (35%) is applied. Then
        each of the following is applied independently: additive Gaussian noise
        (50%), volume scaling (50%), a 5th-order Butterworth low-pass (25%) and
        a circular time shift of up to +/-0.1 s (30%).

        Draws come from the global ``random`` and ``numpy.random`` generators,
        so seed both for reproducible output.
        """
        import random

        augmented = audio.copy()
        rand = random.random()

        if rand < 0.45:
            # Time stretch by 0.85x-1.15x (changes the signal length).
            rate = random.uniform(0.85, 1.15)
            augmented = librosa.effects.time_stretch(augmented, rate=rate)
        elif rand < 0.8:
            # Pitch shift by -2.5 to +2.5 semitones.
            n_steps = random.uniform(-2.5, 2.5)
            augmented = librosa.effects.pitch_shift(
                augmented, sr=self.sample_rate, n_steps=n_steps
            )

        if random.random() < 0.5:
            noise = np.random.randn(len(augmented)) * random.uniform(0.003, 0.008)
            augmented = augmented + noise

        if random.random() < 0.5:
            augmented = augmented * random.uniform(0.75, 1.25)

        if random.random() < 0.25:
            cutoff_freq = random.uniform(3000, 8000)
            # butter() requires 0 < cutoff < fs/2. Clamp below Nyquist so lower
            # sample rates (e.g. 8-16 kHz) do not raise; at the default 22050 Hz
            # the bound (9922.5 Hz) is above the 8000 Hz maximum and never applies.
            cutoff_freq = min(cutoff_freq, 0.45 * self.sample_rate)
            sos = signal.butter(
                5, cutoff_freq, btype="low", fs=self.sample_rate, output="sos"
            )
            augmented = signal.sosfilt(sos, augmented)

        if random.random() < 0.3:
            max_shift = int(0.1 * self.sample_rate)
            shift = random.randint(-max_shift, max_shift)
            augmented = np.roll(augmented, shift)  # samples wrap around the ends

        return augmented

    def extract_multi_resolution_features(self, audio, target_time_dim=259):
        """Compute one normalised log-mel spectrogram per entry of ``self.mel_bands``.

        Returns a dict mapping ``"mel_<n_mels>"`` to a float32 array of shape
        ``(n_mels, target_time_dim, 1)``. Each spectrogram is converted to dB
        relative to its maximum and standardised to zero mean and unit variance.
        It is then zero-padded or cropped along the time axis.
        """
        features = {}
        for n_mels in self.mel_bands:
            mel = librosa.feature.melspectrogram(
                y=audio,
                sr=self.sample_rate,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                n_mels=n_mels,
                power=2.0,
            )
            mel_db = librosa.power_to_db(mel, ref=np.max)
            # Per-spectrogram standardisation; epsilon guards a constant input.
            mel_db = (mel_db - np.mean(mel_db)) / (np.std(mel_db) + 1e-8)
            mel_db = mel_db.astype(np.float32)
            if mel_db.shape[1] < target_time_dim:
                pad_width = target_time_dim - mel_db.shape[1]
                mel_db = np.pad(mel_db, ((0, 0), (0, pad_width)), mode="constant")
            elif mel_db.shape[1] > target_time_dim:
                mel_db = mel_db[:, :target_time_dim]
            features[f"mel_{n_mels}"] = np.expand_dims(mel_db, axis=-1)
        return features

    def mixup_data(self, X, y, alpha=0.4):
        """Return a mixup-blended ``(X, y)`` batch.

        One ``lam ~ Beta(alpha, alpha)`` and one random permutation are shared
        by every input. ``X`` may be a single array or a list of arrays whose
        first axis is the batch. Each output is ``lam * v + (1 - lam) * v[perm]``.
        """
        lam = np.random.beta(alpha, alpha)
        batch_size = len(X[0]) if isinstance(X, list) else len(X)
        index = np.random.permutation(batch_size)

        if isinstance(X, list):
            mixed_X = [lam * x + (1 - lam) * x[index] for x in X]
        else:
            mixed_X = lam * X + (1 - lam) * X[index]

        mixed_y = lam * y + (1 - lam) * y[index]
        return mixed_X, mixed_y


# Feature caching for faster training
import hashlib
import pickle

def get_cache_path(file_path, augment=False):
    """Return the pickle path under ``CNN/cache`` holding features for *file_path*.

    The file name is the MD5 hex digest of the path string, then ``_aug`` for
    augmented variants, then a version tag. Bumping the tag makes older cache
    files unreachable.

    NOTE(review): the key ignores feature settings (sample rate, mel bands,
    hop length), so the version tag must be bumped whenever they change.
    """
    pieces = [hashlib.md5(file_path.encode()).hexdigest()]
    if augment:
        pieces.append("_aug")
    pieces.append("_v6")  # version tag for the current feature/augmentation pipeline
    return "CNN/cache/" + "".join(pieces) + ".pkl"

def load_cached_features(file_path, augment=False):
    """Return cached features for *file_path*, or ``None`` on a cache miss.

    A missing cache file is a miss. So is one that cannot be unpickled (for
    example a truncated file left by an interrupted run): a warning is printed
    and ``None`` is returned so the caller recomputes, instead of raising.

    NOTE: ``pickle.load`` can execute arbitrary code; only use this with a
    cache directory written by this script.
    """
    cache_path = get_cache_path(file_path, augment)
    try:
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    except FileNotFoundError:
        return None
    except (
        EOFError,
        pickle.UnpicklingError,
        AttributeError,
        ImportError,
        IndexError,
    ) as err:
        print(f"[CACHE] Ignoring unreadable cache file {cache_path}: {err}", flush=True)
        return None

def save_cached_features(file_path, features, augment=False):
    """Pickle *features* to the cache file for *file_path*.

    The data is first written to a temporary file next to the target, then
    renamed over it with ``os.replace``. An interrupted write therefore never
    leaves a truncated cache file behind. The temporary file is removed if
    writing fails, and the error is re-raised.
    """
    cache_path = get_cache_path(file_path, augment)
    # Derive the directory from the path so it always matches get_cache_path().
    cache_dir = os.path.dirname(cache_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    tmp_path = f"{cache_path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(features, f)
        os.replace(tmp_path, cache_path)
    except BaseException:
        # Best-effort cleanup of the partial temp file; the original error wins.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


if __name__ == "__main__":
    print("\n" + "=" * 70, flush=True)
    print("    CNN-BASED MUSIC INSTRUMENT RECOGNITION SYSTEM", flush=True)
    print("    [HIGH ACCURACY MODE: TARGET 95-97% ACCURACY]", flush=True)
    print("    [250 files/instrument | 70% augmentation | 3 resolutions]", flush=True)
    print("=" * 70 + "\n", flush=True)

    # Initialize components
    print("[INIT] Initializing audio processor...", flush=True)
    processor = AudioProcessor(CONFIG)
    print(
        f"[INIT] Using enhanced augmentation with mixup strategy",
        flush=True,
    )
    print("[INIT] Audio processor ready\n", flush=True)

    # Load dataset from IRMAS-TrainingData
    import glob

    print("[DATA] Loading training data from IRMAS dataset...", flush=True)
    data_dir = r"/kaggle/input/cnn-based-music-instrument-recognition-system"
    print(f"[DATA] Dataset location: {data_dir}", flush=True)
    instrument_folders = [
        f for f in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, f))
    ]

    # Mapping actual dataset folder names to the standardized instrument names (from CONFIG['class_map'])
    dataset_folder_to_class_map = {
        "Violin": "violin",
        "Piano": "piano",
        "Organ": "organ",
        "Saxophone": "saxophone",
        "Trumpet": "trumpet",
        "Electro_Guitar": "electric_guitar",  # Specific mapping for 'Electro_Guitar' folder
        "flute": "flute",  # Specific mapping for 'flute' folder (lowercase)
        "Clarinet": "clarinet",
        "Acoustic_Guitar": "acoustic_guitar",
        "Tambourine": "tambourine",
        "vibraphone": "vibraphone",
        "Trombone": "trombone",
        "Shakers": "shakers",
        "Ukulele": "ukulele",
        "Keyboard": "keyboard",
        "Horn": "horn",
        "Harmonica": "harmonica",
        "Mandolin": "mandolin",
        "Hi_Hats": "hi_hats",
        "Floor_Tom": "floor_tom",
        "Drum_set": "drum_set",
        "Harmonium": "harmonium",
        "Dobro": "dobro",
        "Cymbals": "cymbals",
        "Banjo": "banjo",
        "Accordion": "accordion",
        "cowbell": "cowbell",
        "Bass_Guitar": "bass_guitar",
    }

    # Collect unique instruments dynamically
    discovered_instruments = set()
    audio_files = []
    labels = []

    # For mel spectrogram visualization
    sample_files_per_instrument = {}

    # Get a set of canonical class names for direct matching or normalization check
    # This is primarily for the original, smaller class_map. The comprehensive mapping above handles the rest.
    canonical_class_names = set(CONFIG["class_map"].values())

    for inst_folder_name in instrument_folders:
        mapped_label = None

        # 1. Check direct mapping from dataset_folder_to_class_map
        if inst_folder_name in dataset_folder_to_class_map:
            mapped_label = dataset_folder_to_class_map[inst_folder_name]
        # 2. If not found in explicit map, try normalizing folder name and checking against original canonical class names
        #    This secondary check is less critical now that dataset_folder_to_class_map is comprehensive.
        else:
            normalized_folder_name = inst_folder_name.lower().replace(" ", "_")
            if normalized_folder_name in canonical_class_names:
                mapped_label = normalized_folder_name

        if mapped_label:
            print(
                f"[DATA] Scanning folder: {inst_folder_name} (mapped to {mapped_label})",
                flush=True,
            )
            wav_files = glob.glob(os.path.join(data_dir, inst_folder_name, "*.wav"))

            if not wav_files:
                print(
                    f"[DATA]   No .wav files found in {inst_folder_name}. Skipping.",
                    flush=True,
                )
                continue

            discovered_instruments.add(mapped_label)
            audio_files.extend(wav_files)
            labels.extend([mapped_label] * len(wav_files))

            # Store first file for mel spectrogram visualization
            if mapped_label not in sample_files_per_instrument and len(wav_files) > 0:
                sample_files_per_instrument[mapped_label] = wav_files[0]

            print(
                f"[DATA]   Found {len(wav_files)} files for {mapped_label}",
                flush=True,
            )
        else:
            print(
                f"[DATA] Skipping folder: {inst_folder_name} (no mapping found to a known instrument)",
                flush=True,
            )

    # Update CONFIG with discovered instruments
    print(
        f"\n[DATA] Discovered {len(discovered_instruments)} unique instruments: {sorted(discovered_instruments)}",
        flush=True,
    )

    # Limit files per instrument
    max_files_per_instrument = CONFIG["max_files_per_instrument"]
    print(
        f"\n[DATA] Using {max_files_per_instrument} files per instrument...",
        flush=True,
    )
    filtered_files = []
    filtered_labels = []
    instrument_counts = {}

    for file, label in zip(audio_files, labels):
        if label not in instrument_counts:
            instrument_counts[label] = 0
        if instrument_counts[label] < max_files_per_instrument:
            filtered_files.append(file)
            filtered_labels.append(label)
            instrument_counts[label] += 1

    audio_files = filtered_files
    labels = filtered_labels

    print(f"[DATA] Total files to process: {len(audio_files)}", flush=True)
    for label, count in instrument_counts.items():
        print(f"[DATA]   {label}: {count} files", flush=True)

    # Generate mel spectrogram visualizations
    print(
        f"\n[VISUALIZATION] Generating mel spectrograms for all instruments...",
        flush=True,
    )
    # Adjust subplot grid based on the number of discovered instruments
    num_instruments_for_viz = len(sample_files_per_instrument)
    rows = (num_instruments_for_viz + 2) // 3  # Calculate rows dynamically
    fig, axes = plt.subplots(rows, 3, figsize=(15, 4 * rows))
    axes = axes.ravel()

    for i, (instrument, file_path) in enumerate(
        sorted(sample_files_per_instrument.items())
    ):
        try:
            audio = processor.load_audio(file_path)
            mel = librosa.feature.melspectrogram(
                y=audio,
                sr=CONFIG["sample_rate"],
                n_fft=CONFIG["n_fft"],
                hop_length=CONFIG["hop_length"],
                n_mels=128,
            )
            mel_db = librosa.power_to_db(mel, ref=np.max)

            img = librosa.display.specshow(
                mel_db,
                sr=CONFIG["sample_rate"],
                hop_length=CONFIG["hop_length"],
                x_axis="time",
                y_axis="mel",
                ax=axes[i],
                cmap="viridis",
            )
            axes[i].set_title(f"{instrument}", fontsize=12, fontweight="bold")
            axes[i].set_xlabel("Time (s)", fontsize=10)
            axes[i].set_ylabel("Frequency (Hz)", fontsize=10)
            fig.colorbar(img, ax=axes[i], format="%+2.0f dB")
        except Exception as e:
            print(
                f"[WARNING] Could not generate mel spectrogram for {instrument}: {e}",
                flush=True,
            )

    # Hide extra subplots if any
    for i in range(num_instruments_for_viz, len(axes)):
        axes[i].axis("off")

    plt.tight_layout()
    plt.savefig("mel_spectrograms.png", dpi=150, bbox_inches="tight")
    plt.close()
    print(f"[SAVE] Mel spectrograms saved to mel_spectrograms.png", flush=True)

    print(
        "\n[FEATURES] Extracting multi-resolution mel spectrograms with enhanced augmentation...",
        flush=True,
    )

    # Multi-label binarizer for instrument labels
    from sklearn.preprocessing import MultiLabelBinarizer

    X, y = [], []
    processed_files = []
    count = 0

    # Process files with augmentation
    for file, label in zip(audio_files, labels):
        if count % 50 == 0:
            print(
                f"[FEATURES] Processed {count}/{len(audio_files)} files...", flush=True
            )
        try:
            # Try to load from cache first
            cached_features = load_cached_features(file, augment=False)
            if cached_features is not None:
                features = cached_features
            else:
                audio = processor.load_audio(file)
                features = processor.extract_multi_resolution_features(audio)
                save_cached_features(file, features, augment=False)

            X.append([features[f"mel_{n}"] for n in CONFIG["mel_bands"]])
            y.append(
                [label] if isinstance(label, str) else list(label)
            )  # Ensure label is always a list for MLB
            processed_files.append(file)
            count += 1

            # Augmentation with higher ratio for better generalization
            if count <= len(audio_files) * CONFIG["augmentation_ratio"]:
                cached_aug = load_cached_features(file, augment=True)
                if cached_aug is not None:
                    aug_features = cached_aug
                else:
                    audio = processor.load_audio(file)
                    aug_audio = processor.augment_audio(audio)
                    aug_features = processor.extract_multi_resolution_features(
                        aug_audio
                    )
                    save_cached_features(file, aug_features, augment=True)

                X.append([aug_features[f"mel_{n}"] for n in CONFIG["mel_bands"]])
                y.append([label] if isinstance(label, str) else list(label))
                processed_files.append(file + "_aug")

            if count % 50 == 0:
                print(
                    f"Processed {count}/{len(audio_files)} files (with augmentation: {len(X)} samples)..."
                )
                import gc

                gc.collect()  # More frequent garbage collection

        except Exception as e:
            print(f"Skipping file {file} due to error: {e}")

    # Only use the successfully processed samples
    min_len = min(len(X), len(y), len(processed_files))
    X = X[:min_len]
    y = y[:min_len]
    processed_files = processed_files[:min_len]

    # Aggressive memory cleanup
    import gc

    del (
        audio_files,
        labels,
        filtered_files,
        filtered_labels,
    )  # Delete large unused variables
    gc.collect()
    print(f"\n[MEMORY] Aggressive garbage collection completed", flush=True)

    # Get unique instruments
    # Ensure unique_instruments correctly captures all labels from 'y'
    unique_instruments = sorted(
        list(set(l[0] for l in y if l))
    )  # Extract unique labels from nested lists

    print(
        f"[DATA] Training with {len(unique_instruments)} instrument classes: {unique_instruments}",
        flush=True,
    )

    # Binarize labels
    mlb = MultiLabelBinarizer(classes=unique_instruments)
    y = mlb.fit_transform(y)

    print(f"[DATA] Label shape after binarization: {y.shape}", flush=True)

    if len(X) < 2 or len(y) < 2:
        print("Not enough valid samples to train/test. Please add more data.")
    else:
        print(f"Length of X: {len(X)}")
        print(f"Length of y: {len(y)}")
        print(f"Length of processed_files: {len(processed_files)}")

        # Stack samples
        X_np = [
            np.stack([sample[i] for sample in X], axis=0)
            for i in range(len(CONFIG["mel_bands"]))
        ]

        # Print shapes
        print(f"\n[FEATURES] Feature shapes:")
        for i, shape in enumerate([x.shape for x in X_np]):
            print(f"[FEATURES]   Resolution {i+1}: {shape}")

        # Train-test split
        from sklearn.model_selection import train_test_split

        idx = np.arange(len(y))
        train_idx, test_idx, y_train, y_test = train_test_split(
            idx, y, test_size=CONFIG["test_size"], random_state=CONFIG["random_state"]
        )
        X_train = [x[train_idx] for x in X_np]
        X_test = [x[test_idx] for x in X_np]

        print(f"\nTraining samples: {len(train_idx)}, Test samples: {len(test_idx)}")
        print(f"Training data shapes: {[x.shape for x in X_train]}")

        # Initialize model
        num_classes = y.shape[1]
        input_shapes = [(x.shape[1], x.shape[2], 1) for x in X_np]
        print(
            f"\n[MODEL] Creating optimized CNN model for {num_classes} classes...",
            flush=True,
        )
        print(f"[MODEL] Input shapes: {input_shapes}", flush=True)
        print(
            f"[MODEL] Architecture: 4 Conv blocks + 3 Dense layers with strong regularization",
            flush=True,
        )
        model = MultiResolutionCNN(
            input_shapes=input_shapes,
            num_classes=num_classes,
        )

        print(f"\n[MODEL] Model Summary:")
        model.model.summary()

        # Train model
        print(
            f"\n[TRAINING] Starting training with high-accuracy parameters...",
            flush=True,
        )
        print(f"[TRAINING] Batch size: 12 (optimized for accuracy)", flush=True)
        print(
            f"[TRAINING] Max epochs: 100 with early stopping (patience=30)\n",
            flush=True,
        )
        history = model.train(X_train, y_train, X_test, y_test)

        # Evaluate
        test_loss, test_acc, test_auc = model.evaluate(X_test, y_test)
        print(f"\n[RESULTS] Final Test Accuracy: {test_acc:.2%}")
        print(f"[RESULTS] Final Test AUC: {test_auc:.4f}")

        # Save model
        model.model.save("instrument_classifier_v3_optimized.keras")
        print(
            f"\n[SAVE] Model saved to instrument_classifier_v3_optimized.keras",
            flush=True,
        )

        # Enhanced Visualization
        fig, axes = plt.subplots(1, 2, figsize=(15, 5))

        # Plot 1: Accuracy
        axes[0].plot(
            history.history["accuracy"], label="Training Accuracy", linewidth=2
        )
        axes[0].plot(
            history.history["val_accuracy"], label="Validation Accuracy", linewidth=2
        )
        axes[0].set_title("Model Accuracy Over Epochs", fontsize=14, fontweight="bold")
        axes[0].set_xlabel("Epoch", fontsize=12)
        axes[0].set_ylabel("Accuracy", fontsize=12)
        axes[0].legend(fontsize=10)
        axes[0].grid(True, alpha=0.3)

        # Calculate and display final gap
        final_train_acc = history.history["accuracy"][-1]
        final_val_acc = history.history["val_accuracy"][-1]
        gap = abs(final_train_acc - final_val_acc)
        gap_status = (
            "✓ Excellent" if gap < 0.05 else "✓ Good" if gap < 0.08 else "⚠ Check"
        )
        axes[0].text(
            0.02,
            0.98,
            f"Final Gap: {gap:.2%} {gap_status}",
            transform=axes[0].transAxes,
            fontsize=10,
            verticalalignment="top",
            bbox=dict(
                boxstyle="round",
                facecolor="lightgreen" if gap < 0.05 else "wheat",
                alpha=0.7,
            ),
        )

        # Plot 2: Loss
        axes[1].plot(history.history["loss"], label="Training Loss", linewidth=2)
        axes[1].plot(history.history["val_loss"], label="Validation Loss", linewidth=2)
        axes[1].set_title("Model Loss Over Epochs", fontsize=14, fontweight="bold")
        axes[1].set_xlabel("Epoch", fontsize=12)
        axes[1].set_ylabel("Loss", fontsize=12)
        axes[1].legend(fontsize=10)
        axes[1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig("training_analysis.png", dpi=150, bbox_inches="tight")
        plt.close()
        print(f"[SAVE] Training analysis saved to training_analysis.png", flush=True)

        # Generate confusion matrix and classification report
        print(
            f"\n[EVALUATION] Generating confusion matrix and classification report...",
            flush=True,
        )
        y_pred = model.model.predict(X_test, verbose=0)
        y_pred_classes = (y_pred > 0.5).astype(int)
        y_test_classes = y_test.astype(int)

        # Get class names
        class_names = list(mlb.classes_)

        # Multi-label classification report
        print(f"\n{'='*60}")
        print(f"CLASSIFICATION REPORT (Per Instrument)")
        print(f"{'='*60}")
        report = classification_report(
            y_test_classes, y_pred_classes, target_names=class_names, zero_division=0
        )
        print(report)

        # Confusion matrix for each instrument
        # Adjust subplot grid based on the number of discovered instruments
        num_classes_cm = len(class_names)
        rows_cm = (num_classes_cm + 2) // 3  # Calculate rows dynamically
        fig, axes = plt.subplots(rows_cm, 3, figsize=(15, 4 * rows_cm))
        axes = axes.ravel()

        for i, instrument in enumerate(class_names):
            cm = confusion_matrix(y_test_classes[:, i], y_pred_classes[:, i])
            sns.heatmap(
                cm,
                annot=True,
                fmt="d",
                cmap="Blues",
                ax=axes[i],
                xticklabels=["Negative", "Positive"],
                yticklabels=["Negative", "Positive"],
            )
            axes[i].set_title(f"{instrument}", fontsize=12, fontweight="bold")
            axes[i].set_ylabel("True Label", fontsize=10)
            axes[i].set_xlabel("Predicted Label", fontsize=10)

        # Hide extra subplots if any
        for i in range(num_classes_cm, len(axes)):
            axes[i].axis("off")

        plt.tight_layout()
        plt.savefig("confusion_matrices.png", dpi=150, bbox_inches="tight")
        plt.close()
        print(f"[SAVE] Confusion matrices saved to confusion_matrices.png", flush=True)

        # Print detailed results
        print(f"\n{'='*60}")
        print(f"FINAL TRAINING RESULTS")
        print(f"{'='*60}")
        print(f"Training Accuracy:   {final_train_acc:.2%}")
        print(f"Validation Accuracy: {final_val_acc:.2%}")
        gap_result = (
            "✓ Excellent - Target Achieved!"
            if gap < 0.05 and final_val_acc >= 0.95
            else "✓ Good" if gap < 0.08 else "⚠ Check for overfitting"
        )
        print(f"Accuracy Gap:        {gap:.2%} {gap_result}")
        print(f"Test Accuracy:       {test_acc:.2%}")
        test_result = "✓ TARGET ACHIEVED!" if test_acc >= 0.95 else "⚠ Below target"
        print(f"Test Status:         {test_result}")
        print(f"Test AUC:            {test_auc:.4f}")
        print(f"{'='*60}\n")

        # Sliding window inference for instrument intensity over time
        def sliding_window_predict(audio_file, window_size=1.0, hop_size=0.5):
            audio, sr = librosa.load(audio_file, sr=CONFIG["sample_rate"])
            duration = librosa.get_duration(y=audio, sr=sr)
            times = np.arange(0, duration - window_size, hop_size)
            all_preds = []
            for t in times:
                start = int(t * sr)
                end = int((t + window_size) * sr)
                segment = audio[start:end]
                if len(segment) < int(window_size * sr):
                    # Pad with zeros if segment is shorter than expected
                    pad = np.zeros(int(window_size * sr) - len(segment))
                    segment = np.concatenate([segment, pad])
                features = processor.extract_multi_resolution_features(segment)
                X_window = [
                    np.expand_dims(features[f"mel_{n}"], axis=0)
                    for n in CONFIG["mel_bands"]
                ]
                pred = model.model.predict(X_window, verbose=0)
                all_preds.append(pred[0])
            return np.array(all_preds), times

        # Example: run sliding-window inference on the first successfully processed
        # file and export the result as a PNG timeline, a JSON file and a PDF report.
        if len(processed_files) > 0:
            import datetime as _dt
            import json

            from matplotlib.backends.backend_pdf import PdfPages

            report_audio = processed_files[0]
            # Window length (seconds), passed explicitly so the report can use it
            # to state how much audio the analysis covered.
            report_window_size = 1.0
            preds, times = sliding_window_predict(
                report_audio, window_size=report_window_size
            )

            if len(times) == 0:
                # No analysis windows means no predictions; preds[:, i] and
                # times[-1] below would raise on the empty result.
                print(
                    f"[EXPORT] Skipped report for {report_audio}: "
                    f"shorter than one {report_window_size:g}s analysis window"
                )
            else:
                class_names = list(mlb.classes_)
                pretty_names = [name.replace("_", " ").title() for name in class_names]
                # Peak probability each class reaches in any window.
                peak_conf = [float(c) for c in np.max(preds, axis=0)]
                # Probability above which an instrument is reported as present.
                detect_threshold = 0.5
                n_detected = sum(conf > detect_threshold for conf in peak_conf)
                # `times` holds window START times, so the analysed span ends one
                # window after the last start (times[-1] alone under-reports it).
                analyzed_duration = float(times[-1]) + report_window_size

                # Visualization: per-class probability over time.
                plt.figure(figsize=(10, 6))
                for i, inst in enumerate(class_names):
                    plt.plot(times, preds[:, i], label=inst)
                plt.xlabel("Time (s)")
                plt.ylabel("Predicted Probability")
                plt.title("Instrument Intensity Over Time")
                plt.legend()
                plt.tight_layout()
                plt.savefig("instrument_intensity_timeline.png")
                plt.close()

                # JSON export: peak confidence per class plus the full timeline.
                result = {
                    "audio_file": str(report_audio),
                    "detected_instruments": dict(zip(class_names, peak_conf)),
                    "timeline": [
                        {
                            "time": float(t),
                            **{
                                inst: float(preds[j, i])
                                for i, inst in enumerate(class_names)
                            },
                        }
                        for j, t in enumerate(times)
                    ],
                }
                with open(
                    "instrument_recognition_result.json", "w", encoding="utf-8"
                ) as f:
                    json.dump(result, f, indent=2)
                print(
                    "Exported instrument recognition result to instrument_recognition_result.json"
                )

                # PDF export using matplotlib
                pdf_filename = "instrument_recognition_report.pdf"
                with PdfPages(pdf_filename) as pdf:
                    # Page 1: Title and Summary
                    fig = plt.figure(figsize=(8.5, 11))
                    fig.text(
                        0.5,
                        0.95,
                        "VANIDYA AI - Instrument Recognition Report",
                        ha="center",
                        fontsize=16,
                        fontweight="bold",
                    )
                    fig.text(
                        0.5,
                        0.90,
                        f"Audio File: {report_audio}",
                        ha="center",
                        fontsize=10,
                    )
                    fig.text(
                        0.5,
                        0.87,
                        f"Analysis Date: {_dt.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
                        ha="center",
                        fontsize=9,
                        style="italic",
                    )

                    # Detected Instruments Summary
                    fig.text(
                        0.1, 0.80, "Detected Instruments:", fontsize=12, fontweight="bold"
                    )
                    # A fixed 0.04 step pushes a large class list below y=0 (off the
                    # letter page); squeeze the lines into a 0.55-high band instead.
                    line_step = min(0.04, 0.55 / max(len(class_names), 1))
                    y_pos = 0.75
                    for pretty, conf in zip(pretty_names, peak_conf):
                        present = conf > detect_threshold
                        status = "Present" if present else "Not Present"
                        fig.text(
                            0.15,
                            y_pos,
                            f"• {pretty}: {status} (Confidence: {conf:.2%})",
                            fontsize=10,
                            color="green" if present else "red",
                        )
                        y_pos -= line_step

                    # Statistics
                    fig.text(
                        0.1,
                        y_pos - 0.05,
                        "Analysis Statistics:",
                        fontsize=12,
                        fontweight="bold",
                    )
                    y_pos -= 0.10
                    fig.text(
                        0.15,
                        y_pos,
                        f"• Total Instruments Detected: {n_detected}",
                        fontsize=10,
                    )
                    y_pos -= 0.04
                    fig.text(
                        0.15,
                        y_pos,
                        f"• Audio Duration: {analyzed_duration:.2f} seconds",
                        fontsize=10,
                    )
                    y_pos -= 0.04
                    fig.text(
                        0.15,
                        y_pos,
                        f"• Time Windows Analyzed: {len(times)}",
                        fontsize=10,
                    )

                    plt.axis("off")
                    pdf.savefig(fig, bbox_inches="tight")
                    plt.close(fig)

                    # Page 2: Instrument Intensity Timeline
                    fig = plt.figure(figsize=(11, 8.5))
                    for i, pretty in enumerate(pretty_names):
                        plt.plot(times, preds[:, i], label=pretty, linewidth=2)
                    plt.xlabel("Time (seconds)", fontsize=12)
                    plt.ylabel("Predicted Probability", fontsize=12)
                    plt.title(
                        "Instrument Intensity Over Time", fontsize=14, fontweight="bold"
                    )
                    plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left", fontsize=9)
                    plt.grid(True, alpha=0.3)
                    plt.tight_layout()
                    pdf.savefig(fig, bbox_inches="tight")
                    plt.close(fig)

                    # Page 3: Confidence Bar Chart
                    fig = plt.figure(figsize=(8.5, 11))
                    colors = [
                        "green"
                        if conf > detect_threshold
                        else "orange"
                        if conf > 0.3
                        else "red"
                        for conf in peak_conf
                    ]
                    plt.barh(pretty_names, peak_conf, color=colors, alpha=0.7)
                    plt.xlabel("Maximum Confidence", fontsize=12)
                    plt.title(
                        "Instrument Detection Confidence", fontsize=14, fontweight="bold"
                    )
                    plt.xlim(0, 1)
                    plt.axvline(
                        x=detect_threshold,
                        color="black",
                        linestyle="--",
                        linewidth=1,
                        alpha=0.5,
                        label=f"Threshold ({detect_threshold})",
                    )
                    plt.legend()
                    plt.grid(True, alpha=0.3, axis="x")
                    plt.tight_layout()
                    pdf.savefig(fig, bbox_inches="tight")
                    plt.close(fig)

                print(f"[EXPORT] PDF report saved to {pdf_filename}")

# Closing summary banner. Per-instrument file cap and augmentation ratio are read
# from CONFIG so the printed sample counts stay in sync with the actual settings.
_files_per_instrument = CONFIG["max_files_per_instrument"]
_augmentation_ratio = CONFIG["augmentation_ratio"]
# NOTE(review): number of instrument classes in the scanned dataset; it is not
# stored in CONFIG, so confirm it matches the classes actually loaded.
_n_instruments = 28
_base_samples = _files_per_instrument * _n_instruments
# Base samples plus the augmented share (e.g. 7,000 * 1.7 = 11,900).
_total_samples = round(_base_samples * (1 + _augmentation_ratio))

print("[INFO] High Accuracy Target: 95-97% validation & test accuracy")
print(
    f"[INFO] Configuration: {_files_per_instrument} files × {_n_instruments} "
    f"instruments = {_base_samples:,} base samples"
)

print(
    f"[INFO] With {_augmentation_ratio:.0%} augmentation: "
    f"~{_total_samples:,} total training samples"
)
print("[INFO] Check results above to verify target achievement!")


[STARTUP] Environment variables configured
[STARTUP] Loading libraries (this may take 30-60 seconds)...


2026-01-16 17:28:51.456794: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1768584531.642405      24 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1768584531.694978      24 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1768584532.129219      24 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1768584532.129257      24 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1768584532.129260      24 computation_placer.cc:177] computation placer alr

[MEMORY] Configured 2 GPU(s) with memory growth
[STARTUP] All libraries loaded successfully!

    CNN-BASED MUSIC INSTRUMENT RECOGNITION SYSTEM
    [HIGH ACCURACY MODE: TARGET 95-97% ACCURACY]
    [250 files/instrument | 70% augmentation | 3 resolutions]

[INIT] Initializing audio processor...
[INIT] Using enhanced augmentation with mixup strategy
[INIT] Audio processor ready

[DATA] Loading training data from IRMAS dataset...
[DATA] Dataset location: /kaggle/input/cnn-based-music-instrument-recognition-system
[DATA] Scanning folder: Piano (mapped to piano)
[DATA]   Found 575 files for piano
[DATA] Scanning folder: cowbell (mapped to cowbell)
[DATA]   Found 621 files for cowbell
[DATA] Scanning folder: Bass_Guitar (mapped to bass_guitar)
[DATA]   Found 3613 files for bass_guitar
[DATA] Scanning folder: Mandolin (mapped to mandolin)
[DATA]   Found 2458 files for mandolin
[DATA] Scanning folder: vibraphone (mapped to vibraphone)
[DATA]   Found 506 files for vibraphone
[DATA] Scanning fol

I0000 00:00:1768585050.273866      24 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13942 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5
I0000 00:00:1768585050.277691      24 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13942 MB memory:  -> device: 1, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5



[MODEL] Model Summary:



[TRAINING] Starting training with high-accuracy parameters...
[TRAINING] Batch size: 12 (optimized for accuracy)
[TRAINING] Max epochs: 100 with early stopping (patience=30)

Epoch 1/100


I0000 00:00:1768585073.469972      79 service.cc:152] XLA service 0x7e1bcc0020f0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1768585073.470036      79 service.cc:160]   StreamExecutor device (0): Tesla T4, Compute Capability 7.5
I0000 00:00:1768585073.470042      79 service.cc:160]   StreamExecutor device (1): Tesla T4, Compute Capability 7.5
I0000 00:00:1768585075.256217      79 cuda_dnn.cc:529] Loaded cuDNN version 91002
2026-01-16 17:37:58.737191: E external/local_xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
2026-01-16 17:37:58.955961: E external/local_xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
2026-01-16 17:38:00.089403: E external/local_xl

[1m  1/775[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m10:26:05[0m 49s/step - accuracy: 0.0000e+00 - auc: 0.4532 - loss: 1.5280

I0000 00:00:1768585106.460223      79 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m775/775[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 84ms/step - accuracy: 0.0465 - auc: 0.5309 - loss: 1.0477

2026-01-16 17:39:35.463930: E external/local_xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
2026-01-16 17:39:35.689289: E external/local_xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
2026-01-16 17:39:36.855515: E external/local_xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
2026-01-16 17:39:37.080511: E external/local_xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
2026-01-16 17:39:38.405605: E external/local_xla/xla/stream_

[1m775/775[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m132s[0m 108ms/step - accuracy: 0.0465 - auc: 0.5310 - loss: 1.0474 - val_accuracy: 0.0473 - val_auc: 0.4978 - val_loss: 0.5954 - learning_rate: 1.0000e-04
Epoch 2/100
[1m775/775[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m71s[0m 92ms/step - accuracy: 0.0805 - auc: 0.6057 - loss: 0.5342 - val_accuracy: 0.0636 - val_auc: 0.5436 - val_loss: 0.5516 - learning_rate: 1.0000e-04
Epoch 3/100
[1m775/775[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m72s[0m 93ms/step - accuracy: 0.1269 - auc: 0.6759 - loss: 0.4870 - val_accuracy: 0.0744 - val_auc: 0.5345 - val_loss: 0.5888 - learning_rate: 1.0000e-04
Epoch 4/100
[1m775/775[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m74s[0m 95ms/step - accuracy: 0.1929 - auc: 0.7405 - loss: 0.4536 - val_accuracy: 0.0851 - val_auc: 0.5497 - val_loss: 0.5376 - learning_rate: 1.0000e-04
Epoch 5/100
[1m775/775[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m74s[0m 95ms/step - accuracy: 0.2675 - 

2026-01-16 19:41:42.680475: E external/local_xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
2026-01-16 19:41:42.923813: E external/local_xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
2026-01-16 19:41:44.751546: E external/local_xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
2026-01-16 19:41:44.996039: E external/local_xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
2026-01-16 19:41:47.198092: E external/local_xla/xla/stream_


[RESULTS] Final Test Accuracy: 96.82%
[RESULTS] Final Test AUC: 0.9971

[SAVE] Model saved to instrument_classifier_v3_optimized.keras
[SAVE] Training analysis saved to training_analysis.png

[EVALUATION] Generating confusion matrix and classification report...

CLASSIFICATION REPORT (Per Instrument)
                 precision    recall  f1-score   support

      accordion       0.94      0.86      0.90        51
acoustic_guitar       0.98      0.96      0.97        89
          banjo       0.98      1.00      0.99        83
    bass_guitar       0.98      1.00      0.99       112
       clarinet       1.00      0.96      0.98       109
        cowbell       1.00      1.00      1.00        93
        cymbals       0.85      1.00      0.92        77
          dobro       0.98      0.97      0.98       102
       drum_set       1.00      1.00      1.00        38
electric_guitar       1.00      0.92      0.96        96
      floor_tom       0.96      0.94      0.95        54
          fl

2026-01-16 19:42:20.036447: E external/local_xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
2026-01-16 19:42:20.235336: E external/local_xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
2026-01-16 19:42:20.959845: E external/local_xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
2026-01-16 19:42:21.159136: E external/local_xla/xla/stream_executor/cuda/cuda_timer.cc:86] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
2026-01-16 19:42:21.896271: E external/local_xla/xla/stream_

Exported instrument recognition result to instrument_recognition_result.json
[EXPORT] PDF report saved to instrument_recognition_report.pdf
[INFO] High Accuracy Target: 95-97% validation & test accuracy
[INFO] Configuration: 250 files × 28 instruments = 7,000 base samples
[INFO] With 70% augmentation: ~11,900 total training samples
[INFO] Check results above to verify target achievement!
