In [13]:
from model import fcn_8_vgg
from predict import predict_multiple  
from predict import class_colors

# Input / output folders for prediction.
# NOTE(review): neither name is referenced by the cells visible here.
inp_dir = "dataset/test"
out_dir = "predicted_image"

# Build the FCN-8 segmentation network (27 classes, 224x320 input)
# and restore its trained weights from the checkpoint file.
model = fcn_8_vgg(
    n_classes=27,
    input_height=224,
    input_width=320,
)
model.load_weights('checkpoints/model.weights.h5')

In [14]:
from matplotlib import pyplot as plt
import six
import cv2
import numpy as np
from functions import get_image_array
from predict import visualize_segmentation

def _to_uint8_bgr(seg_img):
    """Convert a segmentation visualisation to a uint8 image in OpenCV channel order.

    Non-uint8 input whose values all lie in [0, 1] is scaled by 255. Any other
    non-uint8 input is treated as already being on a 0-255 scale: it is clipped
    to that range, not rescaled, so the cast cannot wrap around. A 3-channel
    image is converted RGB -> BGR. Every other shape is returned unchanged.
    """
    if seg_img.dtype != np.uint8:
        as_float = seg_img.astype(np.float64)
        # NOTE(review): keras_segmentation-style visualize_segmentation returns
        # floats already in 0-255; blindly multiplying by 255 overflowed them.
        if as_float.size and as_float.max() <= 1.0:
            as_float = as_float * 255
        seg_img = np.clip(as_float, 0, 255).astype(np.uint8)

    # Only a true (H, W, 3) image can be channel-swapped; 2-D masks pass through.
    if seg_img.ndim == 3 and seg_img.shape[2] == 3:
        seg_img = cv2.cvtColor(seg_img, cv2.COLOR_RGB2BGR)
    return seg_img


def predict(model=None, inp=None, out_fname=None,
            checkpoints_path=None, overlay_img=False,
            class_names=None, show_legends=False, colors=class_colors,
            prediction_width=None, prediction_height=None,
            read_image_type=1):
    """Run a segmentation model on one image and return a colour-coded mask.

    Args:
        model: Loaded segmentation model. It must expose ``input_width``,
            ``input_height``, ``output_width``, ``output_height``, ``n_classes``
            and a Keras-style ``predict``. It is required despite the ``None``
            default.
        inp: An image array, or a path that is read with
            ``cv2.imread(inp, read_image_type)``.
        out_fname: Optional output path. ``.png`` is appended unless the name
            already ends in ``.png`` (case-insensitive).
        colors: Per-class colours forwarded to ``visualize_segmentation``.
        read_image_type: ``cv2.imread`` flag used when ``inp`` is a path.
        checkpoints_path, overlay_img, class_names, show_legends,
        prediction_width, prediction_height: Accepted for signature
            compatibility but currently ignored.

    Returns:
        The segmentation visualisation as a uint8 array. 3-channel results
        are converted RGB -> BGR for OpenCV.

    Raises:
        AssertionError: if ``model`` or ``inp`` is missing, the image file
            cannot be read, or the array has an unsupported number of
            dimensions.
    """
    assert model is not None, "A model must be provided."
    assert inp is not None, "Input must be provided."
    assert isinstance(inp, (np.ndarray, six.string_types)), \
        "Input should be a NumPy array or a file path string."

    if isinstance(inp, six.string_types):
        inp_path = inp
        inp = cv2.imread(inp_path, read_image_type)
        # Keep the path around: `inp` is None here when loading fails.
        assert inp is not None, f"Image at path {inp_path} could not be loaded."

    # NOTE(review): 2-D (grayscale) arrays are rejected by this check.
    assert inp.ndim in [1, 3, 4], "Image should have 1, 3, or 4 dimensions."

    n_classes = model.n_classes
    x = get_image_array(inp, model.input_width, model.input_height)
    pr = model.predict(np.array([x]))

    # Drop the batch-of-one axis and keep the most likely class per pixel.
    pr = pr.reshape(
        (model.output_height, model.output_width, n_classes)
    ).argmax(axis=-1)

    seg_img = visualize_segmentation(
        pr, inp, n_classes=n_classes, colors=colors
    )
    seg_img = _to_uint8_bgr(seg_img)

    if out_fname is not None:
        if not out_fname.lower().endswith('.png'):
            out_fname += '.png'

        success = cv2.imwrite(out_fname, seg_img)
        if success:
            print(f"Saved segmented image to {out_fname}")
        else:
            print(f"Failed to save image at {out_fname}")

    return seg_img


In [None]:
import os

def predict_and_save_images(model, test_folder, output_folder, class_colors):
    """Predict segmentation for every image in a folder and save the results.

    Args:
        model: Loaded segmentation model, passed straight to ``predict``.
        test_folder: Directory scanned non-recursively for inputs. Files ending
            in ``.jpg``, ``.jpeg`` or ``.png`` (any letter case) are processed
            in sorted order, so repeated runs visit files identically.
        output_folder: Destination directory, created if it does not exist.
            Each result is written under its input's file name.
        class_colors: Per-class colours forwarded to ``predict``.
    """
    image_exts = ('.jpg', '.jpeg', '.png')

    os.makedirs(output_folder, exist_ok=True)

    # os.listdir order is arbitrary; sort it for reproducible runs and logs.
    for file_name in sorted(os.listdir(test_folder)):
        if not file_name.lower().endswith(image_exts):
            continue

        input_path = os.path.join(test_folder, file_name)
        print(f"Processing {input_path}...")

        segmented_image = predict(
            model=model,
            inp=input_path,
            colors=class_colors,
        )

        # NOTE: the input's name (and extension) is reused, so .jpg inputs are
        # saved as lossy JPEGs whose mask colours may not be exact.
        output_path = os.path.join(output_folder, file_name)
        if cv2.imwrite(output_path, segmented_image):
            print(f"Saved segmented image to {output_path}")
        else:
            print(f"Failed to save image at {output_path}")

# Batch prediction: `model` must already be loaded with trained weights,
# and `class_colors` is the palette imported from the `predict` module.
test_folder = 'dataset/test'   # images to segment
output_folder = 'pred'         # where the colour-coded masks are written

# Segment every image in `test_folder` and save each mask to `output_folder`.
predict_and_save_images(
    model=model,
    test_folder=test_folder,
    output_folder=output_folder,
    class_colors=class_colors,
)