In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Conv2DTranspose, Input
import warnings

warnings.filterwarnings("ignore")

In [None]:
# Dataset folders, relative to the notebook's working directory.
# Keras' flow_from_directory reads the image files from sub-folders of each directory.
image_dir, mask_dir = "4/Banana FCN/Images", "4/Banana FCN/Mask"

In [None]:
# Images and masks are loaded by two independent directory iterators whose
# batches are later paired up as (image, mask) training examples.
# flow_from_directory shuffles by default, so both iterators must use the same
# seed and the same batch size to visit files in the same order. Without a
# shared seed, each image would be paired with an unrelated mask.
# NOTE(review): pairing also assumes both folders hold the same number of files
# and that their sorted file names correspond one-to-one. Confirm for this dataset.
# NOTE(review): rescale=1/255 on masks assumes they are stored as 0/255. Confirm.
PAIRING_SEED = 42
shared_flow_args = dict(
    class_mode=None,          # no labels: each iterator yields only pixel arrays
    target_size=(128, 128),
    batch_size=32,
    shuffle=True,
    seed=PAIRING_SEED,        # identical shuffle order for images and masks
)

image_datagen = ImageDataGenerator(rescale=1./255)
mask_datagen = ImageDataGenerator(rescale=1./255)

image_generator = image_datagen.flow_from_directory(
    image_dir,
    color_mode='rgb',
    **shared_flow_args,
)

mask_generator = mask_datagen.flow_from_directory(
    mask_dir,
    color_mode='grayscale',
    **shared_flow_args,
)

In [None]:
# Images and masks are paired for training by `combined_generator` in the
# training cell, which (re)defines `train_generator`; no zipped generator is needed here.

In [None]:
def build_fcnn(input_shape=(128, 128, 3)):
    """Build a small encoder-decoder fully convolutional network for binary masks.

    Two 2x2 max-pools halve the spatial size twice. Two stride-2 transposed
    convolutions then double it twice. As a result, the output has the same
    height and width as the input, with one sigmoid channel per pixel.

    Args:
        input_shape: (height, width, channels) of the input images. Height and
            width, when known, must be divisible by 4 so that the decoder
            restores the input size exactly. Defaults to (128, 128, 3), which
            matches the generators' target_size of (128, 128) with RGB input.

    Returns:
        An uncompiled Keras ``Model`` mapping images to 1-channel masks.

    Raises:
        ValueError: if a known height or width is not divisible by 4.
    """
    if any(dim is not None and dim % 4 for dim in input_shape[:2]):
        raise ValueError(
            f"input height and width must be divisible by 4, got {tuple(input_shape[:2])}"
        )

    inputs = Input(tuple(input_shape))
    # Encoder: 128 -> 64 -> 32 spatially (for the default input).
    conv1 = Conv2D(128, (3, 3), activation='relu', padding='same')(inputs)
    pool1 = MaxPooling2D((2, 2))(conv1)
    conv2 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool1)
    pool2 = MaxPooling2D((2, 2))(conv2)
    conv3 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool2)
    # Decoder: each stride-2 transposed conv doubles height and width.
    up1 = Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv3)
    conv4 = Conv2D(128, (3, 3), activation='relu', padding='same')(up1)
    up2 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv4)
    # 1x1 conv collapses channels to a single per-pixel foreground probability.
    outputs = Conv2D(1, (1, 1), activation='sigmoid', padding='same')(up2)
    model = Model(inputs, outputs)
    return model

In [None]:
model = build_fcnn()
# Pixel-wise binary segmentation: sigmoid output trained with binary cross-entropy.
model.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

def combined_generator(image_gen, mask_gen):
    """Yield ``(img_batch, mask_batch)`` pairs for ``model.fit``.

    One batch is drawn from each source per step, image first, then mask.
    Correct pairing relies on both sources producing batches in the same
    order, e.g. two ``flow_from_directory`` iterators with identical
    ``shuffle``/``seed``/``batch_size`` settings. This function cannot verify
    that ordering; it only checks that batch lengths agree.

    Args:
        image_gen: iterable (or iterator) of input image batches.
        mask_gen: iterable (or iterator) of target mask batches.

    Yields:
        Tuples ``(img_batch, mask_batch)``. Iteration stops cleanly when either
        source is exhausted; Keras directory iterators loop indefinitely.

    Raises:
        ValueError: if a paired image batch and mask batch differ in length,
            which means the two sources have drifted out of step.
    """
    # zip pulls from image_gen then mask_gen each step, like next()/next(), but
    # ends cleanly on finite sources instead of raising RuntimeError (PEP 479).
    for img_batch, mask_batch in zip(image_gen, mask_gen):
        if len(img_batch) != len(mask_batch):
            raise ValueError(
                f"image batch has {len(img_batch)} samples but mask batch has "
                f"{len(mask_batch)}; image and mask sources are out of step"
            )
        yield img_batch, mask_batch

train_generator = combined_generator(image_generator, mask_generator)
history = model.fit(
    train_generator,
    epochs=20,
    # The paired generator never ends on its own, so an epoch is defined as
    # len(image_generator) batches.
    steps_per_epoch=len(image_generator),
)

In [None]:
# Metric names recorded per epoch by model.fit; the plotting cell reads the
# 'accuracy' and 'loss' entries.
history.history.keys()

In [None]:
# Training curves. Accuracy and loss share one y-axis; loss is not bounded by 1.
fig, ax = plt.subplots(figsize=(16, 4))
ax.plot(history.history['accuracy'], label='Accuracy', color='blue', linewidth=2)
ax.plot(history.history['loss'], label='Loss', color='red', linewidth=2)
ax.set_title('Training Accuracy and Loss')
ax.set_xlabel('Epoch')
ax.legend()
plt.show()

In [None]:
def predict(sample_image, fcn_model=None):
    """Show an image side by side with the mask the network predicts for it.

    Args:
        sample_image: a single image without a batch axis, e.g. one element of
            a batch from ``image_generator`` (presumably (128, 128, 3) floats
            in [0, 1]; TODO confirm for other callers).
        fcn_model: Keras model used for the prediction. Defaults to the
            notebook's global ``model``.
    """
    if fcn_model is None:
        fcn_model = model
    # predict() expects a batch: add a leading axis, then take the one result back out.
    # verbose=0 avoids a progress bar per call when previewing many images.
    predicted_mask = fcn_model.predict(np.expand_dims(sample_image, axis=0), verbose=0)[0]

    fig, (ax_image, ax_mask) = plt.subplots(1, 2, figsize=(10, 5))

    ax_image.set_title("Original Image")
    ax_image.imshow(sample_image)

    ax_mask.set_title("Predicted Mask")
    # Pin the colour range to the sigmoid's [0, 1] so uncertain predictions
    # are not auto-stretched into a confident-looking black/white mask.
    ax_mask.imshow(predicted_mask.squeeze(), cmap='gray', vmin=0, vmax=1)

    plt.show()
    # Release this figure explicitly. Calling clf() here would instead open a
    # new empty figure, and figures would accumulate when called in a loop.
    plt.close(fig)

In [None]:
# Visual spot-check: predicted masks for up to 20 images from the first training batch.
first_batch = image_generator[0]
for sample in first_batch[:20]:
    predict(sample)