In [18]:
import numpy as np
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras import layers, models


In [20]:
# Load Fashion-MNIST: grayscale clothing images plus integer class labels.
# Scaling and mask construction are done in a separate step.
(X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()
# Downstream code hard-codes 28x28 single-channel images; fail fast if that ever changes.
assert X_train.shape[1:] == (28, 28) and X_test.shape[1:] == (28, 28), (
    f"expected 28x28 images, got {X_train.shape[1:]} / {X_test.shape[1:]}"
)

In [22]:
# Preprocess: scale pixels to [0, 1] as float32 and add the trailing channel axis
# Conv2D expects, then derive per-pixel foreground masks to use as targets.
#
# NOTE: the target mask is a pure function of the input (pixel > 0), so this is a
# toy segmentation task; the class labels y_train / y_test are not used.


def scale_images(images):
    """Return images as float32 in [0, 1] with shape (N, 28, 28, 1).

    Integer (raw uint8) pixels are divided by 255. Floating-point input is assumed
    to be already scaled and is only reshaped, so re-running this cell does not
    divide by 255 a second time. Converting to float32 up front avoids the float64
    array that ``uint8 / 255.0`` would create (Keras computes in float32 by default).
    """
    images = np.asarray(images).reshape(-1, 28, 28, 1)
    if np.issubdtype(images.dtype, np.integer):
        return images.astype(np.float32) / np.float32(255)
    return images.astype(np.float32, copy=False)


def foreground_mask(images):
    """Return a float32 mask with 1.0 where a pixel is non-zero and 0.0 elsewhere.

    Every non-zero pixel counts as the object (foreground); zeros are background.
    float32 matches the dtype of the model's sigmoid output and loss.
    """
    return (np.asarray(images) > 0).astype(np.float32)


X_train = scale_images(X_train)
X_test = scale_images(X_test)
y_train_bin = foreground_mask(X_train)
y_test_bin = foreground_mask(X_test)


In [32]:
def create_model(input_shape=(28, 28, 1), gru_units=128):
    """Build and compile a CNN -> GRU -> upsampling network that predicts per-pixel binary masks.

    The encoder applies two 2x2 max-pools, giving a (H/4, W/4, 256) feature grid.
    That grid is reshaped into a sequence of H/4 * W/4 cells and passed through a
    GRU. The GRU output is reshaped back into a grid and upsampled twice to an
    (H, W, 1) sigmoid mask.

    Args:
        input_shape: (H, W, C) of a single input image. H and W must be multiples
            of 4 so the pooling and upsampling stages round-trip exactly.
            Defaults to 28x28x1, matching the original hard-coded model.
        gru_units: Hidden size of the GRU. This is also the channel count of the
            grid the GRU output is reshaped back into.

    Returns:
        A compiled ``Sequential`` model (Adam optimizer, binary cross-entropy
        loss, 'accuracy' metric).

    Raises:
        ValueError: If H or W is not divisible by 4.
    """
    height, width = input_shape[0], input_shape[1]
    if height % 4 or width % 4:
        raise ValueError(
            f"input height and width must be multiples of 4, got {height}x{width}"
        )
    # Two 2x2 max-pools shrink each spatial side by 4; two 2x2 upsamplings restore it.
    grid_h, grid_w = height // 4, width // 4

    model = models.Sequential()
    # Declare the input with an Input object (the documented Sequential idiom)
    # instead of passing input_shape to the first layer, which Keras 3 warns about.
    model.add(layers.Input(shape=input_shape))

    # Encoder: (H, W, C) -> (H/4, W/4, 256).
    model.add(layers.Conv2D(64, (3, 3), activation='relu', padding='same'))
    model.add(layers.MaxPooling2D(pool_size=(2, 2)))
    model.add(layers.Conv2D(128, (3, 3), activation='relu', padding='same'))
    model.add(layers.MaxPooling2D(pool_size=(2, 2)))
    model.add(layers.Conv2D(256, (3, 3), activation='relu', padding='same'))

    # Treat the grid cells as a sequence of length grid_h * grid_w. The
    # (unidirectional) GRU lets each cell's features depend on the cells before it.
    model.add(layers.Reshape((grid_h * grid_w, 256)))
    model.add(layers.GRU(gru_units, return_sequences=True))
    model.add(layers.Reshape((grid_h, grid_w, gru_units)))

    # Decoder: (H/4, W/4, gru_units) -> (H, W, 1).
    # Conv2DTranspose uses the default stride of 1 here, so it keeps the spatial
    # size; the UpSampling2D layers do the resizing.
    model.add(layers.Conv2DTranspose(128, (3, 3), activation='relu', padding='same'))
    model.add(layers.UpSampling2D(size=(2, 2)))
    model.add(layers.Conv2DTranspose(64, (3, 3), activation='relu', padding='same'))
    model.add(layers.UpSampling2D(size=(2, 2)))
    # 1x1 conv + sigmoid: one foreground probability per pixel (binary segmentation).
    model.add(layers.Conv2D(1, (1, 1), activation='sigmoid'))

    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model


In [34]:
# Instantiate the network with fresh weights; create_model() also compiles it.
model = create_model()
# Print per-layer output shapes and parameter counts.
model.summary()

In [35]:
# Train on the image -> foreground-mask pairs.
# - validation_split=0.2: Keras holds out the LAST 20% of the samples (taken
#   before shuffling) as the validation set.
# - Keep the returned History so per-epoch loss/accuracy can be inspected or
#   plotted, rather than discarding it and echoing its bare repr.
# - Cost: the recorded run took roughly 10 minutes per epoch.
# NOTE(review): no random seed is set anywhere, so results are not reproducible run-to-run.
history = model.fit(
    X_train,
    y_train_bin,
    epochs=10,
    batch_size=32,
    validation_split=0.2,
)

Epoch 1/10
[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m662s[0m 405ms/step - accuracy: 0.9008 - loss: 0.2062 - val_accuracy: 0.9225 - val_loss: 0.1441
Epoch 2/10
[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m613s[0m 400ms/step - accuracy: 0.9227 - loss: 0.1408 - val_accuracy: 0.9234 - val_loss: 0.1327
Epoch 3/10
[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m610s[0m 391ms/step - accuracy: 0.9235 - loss: 0.1310 - val_accuracy: 0.9240 - val_loss: 0.1278
Epoch 4/10
[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m706s[0m 446ms/step - accuracy: 0.9241 - loss: 0.1275 - val_accuracy: 0.9244 - val_loss: 0.1261
Epoch 5/10
[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m790s[0m 526ms/step - accuracy: 0.9186 - loss: 0.1493 - val_accuracy: 0.9241 - val_loss: 0.1277
Epoch 6/10
[1m1500/1500[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m517s[0m 334ms/step - accuracy: 0.9244 - loss: 0.1268 - val_accuracy: 0.9245 - val_loss:

<keras.src.callbacks.history.History at 0x2bb8116f590>

In [41]:
# Score the trained model on the held-out test images and masks.
# evaluate() returns the loss followed by each compiled metric, so this unpacks
# [loss, accuracy]. With binary cross-entropy, 'accuracy' is per-pixel binary accuracy.
eval_scores = model.evaluate(x=X_test, y=y_test_bin)
test_loss, test_acc = eval_scores
print(f"Test Accuracy: {test_acc}")

[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m36s[0m 71ms/step - accuracy: 0.9252 - loss: 0.1233
Test Accuracy: 0.9248540997505188
