In [None]:
"""
Fashion-MNIST Classification

Ten notebook wykorzystuje zbiór Fashion-MNIST (wbudowany w Keras) i trenuje
model CNN do klasyfikacji 10 klas ubrań.

Dodatkowo rysuje macierz pomyłek, a wyniki i logi zapisuje w katalogu `logs/`.
"""

import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf

from sklearn.metrics import confusion_matrix

# Create the output directory up front: the training code in this file writes
# its results file and plots under 'logs/'. exist_ok=True makes re-runs safe.
os.makedirs('logs', exist_ok=True)

def load_fashion_mnist_data():
    """
    Load the Fashion-MNIST dataset through the Keras dataset loader.

    Returns
    -------
    X_train, X_test : ndarray
        Training and test images (28x28 grayscale).
    y_train, y_test : ndarray
        Training and test labels (integers 0-9).
    """
    # The loader yields (images, labels) pairs for the train and test splits;
    # they are regrouped here so that images come first, then labels.
    train_split, test_split = tf.keras.datasets.fashion_mnist.load_data()
    X_train, y_train = train_split
    X_test, y_test = test_split
    return X_train, X_test, y_train, y_test

def preprocess_data(X_train, X_test, y_train, y_test):
    """
    Prepare Fashion-MNIST arrays for training.

    Steps applied to both splits:
    - cast pixels to float32 and divide by 255 (scales 0-255 pixels to [0, 1]),
    - append a trailing channel axis of size 1,
    - one-hot encode the labels into 10 classes.

    Parameters
    ----------
    X_train, X_test : ndarray
        Training and test images.
    y_train, y_test : ndarray
        Integer class labels.

    Returns
    -------
    X_train, X_test, y_train_cat, y_test_cat : ndarray
        Scaled images with a channel axis, and one-hot encoded labels.
    """
    num_classes = 10

    # Scale and add the channel axis in one pass per split.
    images_train, images_test = [
        np.expand_dims(images.astype('float32') / 255.0, -1)
        for images in (X_train, X_test)
    ]

    y_train_cat, y_test_cat = [
        tf.keras.utils.to_categorical(labels, num_classes)
        for labels in (y_train, y_test)
    ]

    return images_train, images_test, y_train_cat, y_test_cat

def build_fashion_model(input_shape):
    """
    Build and compile a small CNN classifier for Fashion-MNIST.

    Architecture: Conv2D(32, 3x3, relu) -> MaxPooling2D(2x2) -> Flatten
    -> Dense(64, relu) -> Dense(10, softmax).

    Parameters
    ----------
    input_shape : tuple
        Shape of a single input sample, e.g. (28, 28, 1).

    Returns
    -------
    model : tf.keras.Model
        Model compiled with the Adam optimizer, categorical cross-entropy
        loss and the accuracy metric.
    """
    layers = tf.keras.layers
    stack = [
        layers.Input(shape=input_shape),
        layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu'),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),
        layers.Dense(units=64, activation='relu'),
        # One output unit per class; softmax yields class probabilities.
        layers.Dense(units=10, activation='softmax'),
    ]

    model = tf.keras.Sequential(stack)
    model.compile(
        loss='categorical_crossentropy',
        optimizer='adam',
        metrics=['accuracy'],
    )
    return model

def _save_training_history_plot(history, path):
    """Draw train/validation accuracy per epoch on a new figure, save it to *path*, then show it."""
    # A dedicated figure is created explicitly so the curves never land on
    # whatever figure happened to be current (e.g. one left from a prior call).
    fig, ax = plt.subplots()
    ax.plot(history.history['accuracy'], label='train acc')
    ax.plot(history.history['val_accuracy'], label='val acc')
    ax.legend()
    ax.set_title('Fashion-MNIST Training History')
    fig.savefig(path)
    plt.show()
    # Release the figure; under non-interactive backends show() does not close it.
    plt.close(fig)


def _save_confusion_matrix_plot(cm, path):
    """Draw the confusion matrix *cm* as an annotated heatmap, save it to *path*, then show it."""
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title('Fashion-MNIST Confusion Matrix')
    fig.savefig(path)
    plt.show()
    plt.close(fig)


def train_and_evaluate_model(X_train, X_test, y_train_cat, y_test_cat, y_test):
    """
    Train the CNN on Fashion-MNIST, evaluate it and write the artifacts.

    The model is fitted for 10 epochs with batch size 64, holding out 20% of
    the training data for validation. Files written:
    - logs/results_fashion_mnist.txt (test accuracy)
    - logs/training_history_fashion_mnist.png (train/validation accuracy curves)
    - logs/confusion_matrix.png (confusion matrix on the test set)

    Parameters
    ----------
    X_train, X_test, y_train_cat, y_test_cat : ndarray
        Images and one-hot labels used for training and testing.
    y_test : ndarray
        Original integer test labels, needed for the confusion matrix.

    Returns
    -------
    None
    """
    # Do not rely solely on module import having created the output directory.
    os.makedirs('logs', exist_ok=True)

    model = build_fashion_model((28, 28, 1))
    history = model.fit(
        X_train,
        y_train_cat,
        validation_split=0.2,
        epochs=10,
        batch_size=64,
        verbose=0,
    )
    # Only the accuracy is reported; the loss value is intentionally discarded.
    _, acc = model.evaluate(X_test, y_test_cat, verbose=0)

    with open('logs/results_fashion_mnist.txt', 'w', encoding='utf-8') as f:
        f.write(f"Test Accuracy: {acc}\n")

    _save_training_history_plot(history, 'logs/training_history_fashion_mnist.png')

    # Confusion matrix: pick the most probable class for every test sample.
    y_pred = model.predict(X_test)
    y_pred_classes = np.argmax(y_pred, axis=1)
    cm = confusion_matrix(y_test, y_pred_classes)

    _save_confusion_matrix_plot(cm, 'logs/confusion_matrix.png')

def main():
    """
    Entry point of the pipeline:
    1. Load the Fashion-MNIST data.
    2. Preprocess it (normalization, channel axis, one-hot labels).
    3. Build, train and evaluate the CNN model.
    4. Plot the confusion matrix and write the logs.
    """
    raw_splits = load_fashion_mnist_data()
    X_train, X_test, y_train_cat, y_test_cat = preprocess_data(*raw_splits)

    # The integer test labels (last element of the raw splits) are kept
    # because the confusion matrix needs class indices, not one-hot vectors.
    y_test = raw_splits[3]
    train_and_evaluate_model(X_train, X_test, y_train_cat, y_test_cat, y_test)

# Run the pipeline only when this code is executed directly (as a script or in
# a notebook cell, where __name__ is "__main__"), not when it is imported.
if __name__ == "__main__":
    main()
