In [1]:
# ===============================
# 1. Imports
# ===============================
import streamlit as st
import tensorflow as tf
import numpy as np
from PIL import Image

# ===============================
# 2. Load & Train Model (cached)
# ===============================
@st.cache_resource
def load_model():
    """Build a small CNN and train it for one epoch on Fashion-MNIST.

    The function is decorated with ``st.cache_resource`` so that Streamlit
    can reuse the returned model instead of retraining it every time the
    script is re-executed.

    Returns:
        The compiled and fitted ``tf.keras.Sequential`` model. Its input is
        a batch of 28x28 single-channel images and its output is a softmax
        over 10 classes.
    """
    layers = tf.keras.layers
    image_shape = (28, 28, 1)

    train_split, test_split = tf.keras.datasets.fashion_mnist.load_data()
    train_images, train_labels = train_split
    test_images, test_labels = test_split

    # Divide by 255.0 to rescale the pixel values, then append a channel
    # axis so the arrays match the Conv2D input shape.
    train_images = (train_images / 255.0).reshape(-1, *image_shape)
    test_images = (test_images / 255.0).reshape(-1, *image_shape)

    # Two conv/pool stages followed by a dense classifier head.
    stack = [
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=image_shape),
        layers.MaxPooling2D(2, 2),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D(2, 2),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(10, activation='softmax'),
    ]
    network = tf.keras.Sequential(stack)

    # Labels are passed to fit() as loaded, with the sparse loss variant.
    network.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )

    # A single silent epoch; the test split is used for validation only.
    network.fit(
        train_images,
        train_labels,
        epochs=1,
        validation_data=(test_images, test_labels),
        verbose=0,
    )
    return network

# Module-level model instance used by predict(); load_model() is wrapped in
# st.cache_resource, so this call is expected to train only on the first run.
model = load_model()

# ===============================
# 3. Class Names
# ===============================
# Human-readable label for each of the 10 model outputs; predict() indexes
# this list with the argmax of the model's prediction.
class_names = [
    "T-shirt/top",  # 0
    "Trouser",      # 1
    "Pullover",     # 2
    "Dress",        # 3
    "Coat",         # 4
    "Sandal",       # 5
    "Shirt",        # 6
    "Sneaker",      # 7
    "Bag",          # 8
    "Ankle boot",   # 9
]

# ===============================
# 4. Prediction Function
# ===============================
def predict(image):
    """Return the predicted class name for an image.

    Args:
        image: An object offering PIL-style ``convert`` and ``resize``
            methods (the UI passes the result of ``Image.open``).

    Returns:
        The entry of ``class_names`` at the index of the highest score in
        the model's prediction.
    """
    # Same preprocessing as training: grayscale, 28x28, divided by 255.0.
    # NOTE(review): no inversion is applied here - confirm that uploaded
    # photos match the foreground/background polarity of the training data.
    grayscale = image.convert("L").resize((28, 28))
    pixels = np.array(grayscale) / 255.0

    # The model takes a batch of single-channel images, so wrap the one
    # image as a batch of size 1.
    batch = pixels.reshape(1, 28, 28, 1)

    scores = model.predict(batch)
    best = np.argmax(scores)
    return class_names[best]

# ===============================
# 5. Streamlit UI
# ===============================
# Page header. The title begins with the brain emoji (U+1F9E0).
st.title("🧠 Fashion MNIST CNN Classifier")
st.write("Upload an image of clothing and the CNN will predict the category.")

# Only JPEG and PNG uploads are offered by the file picker.
uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "png", "jpeg"])

if uploaded_file is not None:
    # Show a preview of the upload before any classification is requested.
    image = Image.open(uploaded_file)
    st.image(image, caption="Uploaded Image", width=200)

    # Classification runs only when the user presses the button.
    if st.button("Predict"):
        result = predict(image)
        st.success(f"Predicted Class: **{result}**")


ModuleNotFoundError: No module named 'streamlit'