# Install necessary packages

In [None]:
# Install the third-party packages this notebook relies on by shelling out to pip.
# NOTE(review): `!pip` runs whichever `pip` is first on PATH, which may not belong
# to the running kernel -- consider the `%pip install ...` magic; confirm for your setup.
!pip install tensorflow numpy opencv-python matplotlib pillow scipy

# Import necessary packages

In [None]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Lambda, Conv2D, MaxPooling2D, BatchNormalization, Flatten, Dense, Dropout
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import numpy as np
import cv2
import os
import matplotlib.pyplot as plt

# Set image and training parameters

In [None]:
# Spatial size (in pixels) that every image is resized to before it reaches the network.
img_width = 28
img_height = 28

# Training hyper-parameters.
batch_size = 32
epochs = 10

# Root folder handed to the directory-based data generators.
data_dir = "dataset"  # Update this to your dataset folder

# Prepare the data generators with a validation split.

In [None]:
# Shared generator configuration: pixel values are multiplied by 1/255, 20% of
# the files are set aside for validation, and small random rotations and shifts
# are applied as light augmentation.
datagen = ImageDataGenerator(
    validation_split=0.2,    # fraction of the files reserved for validation
    rescale=1.0 / 255,       # scale factor applied to every pixel value
    height_shift_range=0.1,  # random vertical shift
    width_shift_range=0.1,   # random horizontal shift
    rotation_range=10,       # random rotation
)

# Read images as RGBA so we have access to the alpha channel.

In [None]:
# Training subset: shuffled batches of RGBA images with one-hot labels, read
# from the sub-folders of `data_dir`.
# NOTE(review): Keras documents `target_size` as (height, width); the order used
# here is only equivalent because the image is square -- confirm before switching
# to a non-square size.
train_generator = datagen.flow_from_directory(
    directory=data_dir,
    subset='training',
    class_mode='categorical',
    color_mode='rgba',
    target_size=(img_width, img_height),
    batch_size=batch_size,
    shuffle=True,
)

# Validation subset.
#
# Validation images must not be randomly rotated or shifted: val_loss drives
# both EarlyStopping and ModelCheckpoint, so it has to be measured on
# undistorted data. A dedicated generator that only rescales is used for that.
# Its `rescale` and `validation_split` must stay equal to the values given to
# `datagen`, so both generators scale pixels identically and reserve the same
# share of files for validation.
val_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

validation_generator = val_datagen.flow_from_directory(
    data_dir,
    target_size=(img_width, img_height),
    color_mode='rgba',
    batch_size=batch_size,
    class_mode='categorical',
    subset='validation'
)

# Print class indices for debugging.

In [None]:
# Show which class-folder name was assigned to which output index; helpful when
# interpreting the network's predictions later on.
class_mapping = train_generator.class_indices
print("Class mapping:", class_mapping)

# Define the CNN model with a Lambda layer to extract the alpha channel.

In [None]:
# CNN classifier that looks only at the alpha channel of an RGBA input.
#
# The input is declared with an explicit `Input` object rather than the
# `input_shape=` argument on the first layer, which recent Keras releases
# deprecate for Sequential models.
# NOTE(review): Keras documents `Lambda` layers built from Python lambdas as
# having (de)serialization limitations -- keep that in mind when saving and
# reloading this model across environments.
model = Sequential([
    tf.keras.Input(shape=(img_width, img_height, 4)),

    # Keep only the alpha channel (4th channel); the 3:4 slice preserves the
    # channel axis, so the output shape is (28, 28, 1) for a 28x28 input.
    Lambda(lambda x: x[..., 3:4]),

    # Three conv / batch-norm / max-pool stages with 32, 64 and 128 filters.
    Conv2D(32, (3, 3), activation='relu', padding='same'),
    BatchNormalization(),
    MaxPooling2D(pool_size=(2, 2)),

    Conv2D(64, (3, 3), activation='relu', padding='same'),
    BatchNormalization(),
    MaxPooling2D(pool_size=(2, 2)),

    Conv2D(128, (3, 3), activation='relu', padding='same'),
    BatchNormalization(),
    MaxPooling2D(pool_size=(2, 2)),

    # Classification head.
    Flatten(),
    Dense(256, activation='relu'),
    Dropout(0.5),
    Dense(10, activation='softmax')  # 10 output classes for digits 0-9
])

# Compile the model.

In [None]:
# Configure training: categorical cross-entropy (the generators are created with
# class_mode='categorical'), the Adam optimizer, and accuracy as the reported metric.
model.compile(
    loss='categorical_crossentropy',
    metrics=['accuracy'],
    optimizer='adam',
)

# Optional callbacks to monitor training and save the best model.

In [None]:
# Stop training once val_loss has gone 3 epochs without improving, and restore
# the weights of the best epoch.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True,
)

# Write the model to disk only when it improves on the best one saved so far.
best_model_checkpoint = ModelCheckpoint(
    'best_digit_recognition_model.h5',
    save_best_only=True,
)

callbacks = [early_stopping, best_model_checkpoint]

# Print the layer-by-layer overview of the network.
model.summary()

# Train the model.

In [None]:
# Run training on the training generator, evaluating on the validation generator
# after each epoch; the returned history object is kept for later inspection.
history = model.fit(
    x=train_generator,
    validation_data=validation_generator,
    epochs=epochs,
    callbacks=callbacks,
)

# Save the final trained model.

In [None]:
# Write the model, in its state after training, to a single .h5 file.
final_model_path = 'digit_recognition_model.h5'
model.save(final_model_path)

# Load the trained model.

In [None]:
# Read the saved model back from disk. compile=False skips re-compiling it,
# which is enough here because the cells below only call predict().
saved_model_path = 'digit_recognition_model.h5'
model = tf.keras.models.load_model(saved_model_path, compile=False)

# Path to an example image (update the path as needed).

In [32]:
# Sample image used by the inference cells below.
example_image_path = "dataset/4/37.png"  # Adjust this path if necessary.

# Check if the file exists.

In [None]:
# Fail early, with a clear message, when the sample image path does not exist.
image_found = os.path.exists(example_image_path)
if not image_found:
    raise FileNotFoundError(f"Image not found at {example_image_path}")

# Load the image in unchanged mode to capture all channels.

In [None]:
# IMREAD_UNCHANGED returns the image as stored in the file, so an alpha channel
# (when present) is kept instead of being dropped.
read_flags = cv2.IMREAD_UNCHANGED
img = cv2.imread(example_image_path, read_flags)

# Verify the image was loaded.

In [None]:
# cv2.imread signals failure by returning None rather than raising, so turn
# that into an explicit error.
image_loaded = img is not None
if not image_loaded:
    raise ValueError("cv2.imread returned None. Check the file or its format.")

original_shape = img.shape
print("Original image shape:", original_shape)

# Ensure the image has 4 channels; if not, convert or expand dimensions appropriately.

In [None]:
# Bring the image to four channels. The network only reads channel index 3
# (alpha), so the colour order of the first three channels does not matter,
# but the presence of a real alpha plane does.
alpha_is_synthetic = False
if img.ndim == 2:
    # Single-channel (grayscale) image: expand to four channels.
    img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGBA)
    alpha_is_synthetic = True
elif img.shape[-1] == 3:
    # OpenCV loads colour images as BGR (not RGB): convert and append alpha.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGBA)
    alpha_is_synthetic = True

if alpha_is_synthetic:
    # cvtColor fills an added alpha channel with a constant (fully opaque)
    # value, which carries no information for an alpha-only model. Say so
    # instead of silently producing an arbitrary prediction.
    print("Warning: the image has no alpha channel; a constant opaque alpha was "
          "added, so the alpha-based prediction will not be meaningful.")

print("Image shape after conversion:", img.shape)

# Preprocess: resize, normalize, and reshape.

In [None]:
# Preprocess the image the same way the training generator does.
# The target size is taken from the loaded model (batch, height, width, channels)
# instead of being hard-coded, so this cell keeps working if the input
# resolution changes.
_, input_height, input_width, _ = model.input_shape

# cv2.resize expects (width, height). Nearest-neighbour matches the default
# interpolation used by flow_from_directory when it resized the training images.
img = cv2.resize(img, (input_width, input_height), interpolation=cv2.INTER_NEAREST)

# Scale by 1/255 (same factor as the generator's `rescale`) in float32.
img = img.astype(np.float32) / 255.0

# Add the leading batch axis: (height, width, 4) -> (1, height, width, 4).
# Note: 4 channels because our model expects RGBA input.
img = img[np.newaxis, ...]

# Extract the alpha channel manually for visualization.

In [None]:
# Display the alpha plane of the (single) image in the batch -- the only
# channel the network actually looks at.
alpha_channel = img[0, :, :, 3]
plt.imshow(alpha_channel, cmap='gray')
plt.axis('off')
plt.title("Alpha Channel")
plt.show()

# Make a prediction.

In [33]:
# Run the network on the single-image batch and pick the most probable class.
predictions = model.predict(img)
predicted_index = int(np.argmax(predictions[0]))

# The network outputs a class *index*. The generator assigns indices to the
# class-folder names (see `class_indices`), so translate the index back to the
# folder name when the training generator is available in this session.
# If only the inference cells were run, fall back to the raw index.
try:
    index_to_label = {
        index: label
        for label, index in train_generator.class_indices.items()
    }
except NameError:
    index_to_label = {}

predicted_digit = index_to_label.get(predicted_index, predicted_index)
print(f'Predicted Digit: {predicted_digit}')

1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 69ms/step
Predicted Digit: 3
