In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datasets import load_dataset
from tensorflow.keras import layers, models
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from wandb.integration.keras import WandbCallback
from tensorflow.keras.callbacks import ModelCheckpoint
import wandb
from huggingface_hub import HfApi

1. Data Preparation
Load and preprocess the TrashNet dataset.

In [None]:
def load_and_preprocess_data(target_size=(224, 224)):
    """
    Load the TrashNet dataset and preprocess images and labels.

    Every image in the dataset's 'train' split is converted to RGB, resized
    to ``target_size`` and scaled to [0, 1]. Labels are one-hot encoded and
    the data is split 80/20 into train and test sets.

    Args:
        target_size: (width, height) passed to the image ``resize`` call.

    Returns:
        Tuple ``(X_train, X_test, y_train, y_test, num_classes)`` where the
        X arrays are float32 with shape (n, height, width, 3), the y arrays
        are one-hot encoded, and ``num_classes`` is the number of distinct
        labels found in the data.
    """
    dataset = load_dataset("garythung/trashnet")
    data_train = dataset['train']

    width, height = target_size
    # Preallocate in float32: dividing a uint8 array by 255.0 yields float64,
    # and building a Python list first keeps a second full copy alive.
    X = np.empty((len(data_train), height, width, 3), dtype=np.float32)
    labels = []

    for i, example in enumerate(data_train):
        # convert("RGB") guarantees exactly three channels, so grayscale or
        # RGBA images cannot break the fixed (height, width, 3) layout.
        img = example['image'].convert("RGB").resize(target_size)
        X[i] = np.asarray(img, dtype=np.float32) / 255.0
        labels.append(example['label'])

    label_ids = np.asarray(labels)
    num_classes = len(np.unique(label_ids))
    y = to_categorical(label_ids, num_classes)

    # Stratify on the integer labels so both splits keep the same class
    # proportions (the one-hot matrix is not suitable as a stratify key).
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=label_ids
    )
    return X_train, X_test, y_train, y_test, num_classes

2. Model Training
Build and train a CNN model.

In [None]:
def build_cnn_model(input_shape, num_classes):
    """
    Build a basic CNN model.

    The network is three 3x3 convolutions (32, 64, 128 filters) with 2x2 max
    pooling after the first two, followed by a 128-unit dense layer and a
    softmax output.

    Args:
        input_shape: Shape of one input sample, without the batch dimension.
        num_classes: Number of units in the softmax output layer.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    model = models.Sequential([
        # Declare the input with an explicit Input layer; passing
        # `input_shape=` to the first Conv2D is deprecated in recent Keras.
        layers.Input(shape=input_shape),
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(num_classes, activation='softmax')
    ])
    return model

def train_model(X_train, y_train, X_test, y_test, num_classes, project_name, entity_name):
    """
    Train a CNN model and log to WandB.

    Builds the model with ``build_cnn_model``, trains it for 10 epochs with
    batch size 32 using ``(X_test, y_test)`` as validation data, and saves
    the final model to disk.

    Args:
        X_train, y_train: Training images and one-hot labels.
        X_test, y_test: Validation images and one-hot labels.
        num_classes: Number of output classes.
        project_name: W&B project passed to ``wandb.init``.
        entity_name: W&B entity passed to ``wandb.init``.

    Returns:
        Tuple ``(cnn_model, history, model_path)`` where ``model_path`` is
        the file the final (last-epoch) model was saved to.
    """
    # The best-epoch checkpoint and the final model get separate files so the
    # final save cannot overwrite the checkpoint.
    checkpoint_path = "trashnet_model.keras"
    # An explicit `.keras` extension is required by recent Keras versions;
    # a bare path without extension is rejected by `Model.save`.
    model_path = "trashnet_model_final.keras"

    wandb.init(project=project_name, entity=entity_name, name="cnn_baseline")
    try:
        cnn_model = build_cnn_model(X_train.shape[1:], num_classes)
        cnn_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

        history = cnn_model.fit(
            X_train, y_train,
            epochs=10,
            batch_size=32,
            validation_data=(X_test, y_test),
            callbacks=[
                WandbCallback(log_graph=False),
                ModelCheckpoint(filepath=checkpoint_path, save_best_only=True)
            ]
        )

        cnn_model.save(model_path)
    finally:
        # Close the W&B run even if training fails, so a re-run of this
        # function does not leave a dangling active run behind.
        wandb.finish()

    return cnn_model, history, model_path


3. Evaluation
Evaluate the model and visualize results.

In [None]:
def plot_training(history):
    """
    Draw the training curves side by side: accuracy on the left, loss on the right.

    Args:
        history: Object whose ``history`` attribute maps the keys
            'accuracy', 'val_accuracy', 'loss' and 'val_loss' to
            per-epoch sequences.
    """
    curves = history.history
    # One entry per panel: (panel title, ((history key, legend label), ...)).
    panels = (
        ("Accuracy", (("accuracy", "Training Accuracy"),
                      ("val_accuracy", "Validation Accuracy"))),
        ("Loss", (("loss", "Training Loss"),
                  ("val_loss", "Validation Loss"))),
    )

    plt.figure(figsize=(12, 4))
    for slot, (heading, series) in enumerate(panels, start=1):
        plt.subplot(1, 2, slot)
        for key, legend_text in series:
            plt.plot(curves[key], label=legend_text)
        plt.title(heading)
        plt.legend()
    plt.show()

def evaluate_model(cnn_model, X_test, y_test):
    """
    Report test accuracy and draw the confusion matrix for a trained model.

    Args:
        cnn_model: Model exposing ``evaluate`` (returning loss and accuracy)
            and ``predict`` (returning per-class scores).
        X_test: Inputs to evaluate on.
        y_test: One-hot labels; its second dimension gives the class count.
    """
    _, accuracy = cnn_model.evaluate(X_test, y_test)
    print(f"Test Accuracy: {accuracy*100:.2f}%")

    # Collapse score rows and one-hot rows to integer class ids.
    class_scores = cnn_model.predict(X_test)
    predicted_ids = np.argmax(class_scores, axis=1)
    actual_ids = np.argmax(y_test, axis=1)

    matrix_view = ConfusionMatrixDisplay(
        confusion_matrix=confusion_matrix(actual_ids, predicted_ids),
        display_labels=range(y_test.shape[1]),
    )
    matrix_view.plot(cmap='viridis')
    plt.title("Confusion Matrix")
    plt.show()

4. Utilities
Helper functions for visualization.

In [None]:
def show_sample_images(X, y, num_classes=6):
    """
    Display sample images for each class.

    One panel is drawn per class index in ``range(num_classes)``, showing the
    first image in ``X`` whose one-hot label in ``y`` has that class as its
    argmax. A class with no samples gets an empty, labelled panel instead of
    raising.

    Args:
        X: Array of images indexed along the first axis.
        y: One-hot label array aligned with ``X``.
        num_classes: Number of panels / class indices to show.
    """
    # Decode the one-hot labels once instead of once per class.
    class_ids = y.argmax(axis=1)

    # squeeze=False always returns a 2-D array of axes, so indexing works
    # even when num_classes == 1 (otherwise a bare Axes object is returned).
    fig, axes = plt.subplots(1, num_classes, figsize=(15, 5), squeeze=False)
    for i, ax in enumerate(axes[0]):
        matches = np.flatnonzero(class_ids == i)
        ax.axis('off')
        if matches.size:
            ax.imshow(X[matches[0]])
            ax.set_title(f"Class {i}")
        else:
            # No example of this class in the given data.
            ax.set_title(f"Class {i} (no samples)")
    plt.show()

def plot_class_distribution(y_train):
    """
    Show a count plot of how many training samples fall into each class.

    Args:
        y_train: One-hot label array; the argmax along axis 1 of each row
            is used as the class index.
    """
    # Decode one-hot rows to integer class indices before counting.
    class_indices = y_train.argmax(axis=1)
    sns.countplot(x=class_indices, palette="Set2")
    plt.title("Class Distribution in Training Data")
    plt.show()

5. Run the Pipeline

In [None]:
# End-to-end pipeline: load data, show exploratory plots, train the CNN,
# then plot the training curves and evaluate on the held-out split.
# The guard runs this only when the code executes as the main module.
if __name__ == "__main__":
    # Data preparation
    # Uses the default target_size of load_and_preprocess_data.
    X_train, X_test, y_train, y_test, num_classes = load_and_preprocess_data()
    show_sample_images(X_train, y_train, num_classes)
    plot_class_distribution(y_train)

    # Model training
    # Both values are forwarded to wandb.init inside train_model.
    project_name = "trashnet-classification"
    # NOTE(review): looks like a placeholder — presumably must be replaced
    # with a real W&B entity before running; confirm.
    entity_name = "your_entity_name"

    # The test split doubles as validation data during training.
    # model_path is returned but not used further in this cell.
    model, history, model_path = train_model(X_train, y_train, X_test, y_test, num_classes, project_name, entity_name)
    
    # Evaluation
    plot_training(history)
    evaluate_model(model, X_test, y_test)