In [1]:
import cv2
import matplotlib
import matplotlib.pyplot as plt
from roipoly import RoiPoly
import numpy as np
import os
matplotlib.use('TkAgg') 
import pickle


# # Training

In [None]:
def masking(image):
    """Let the user draw a polygonal ROI on ``image`` and return it as a mask.

    The image is shown in a figure to which ``RoiPoly`` is attached so the
    user can outline the region; the chosen polygon is then re-displayed over
    the image for confirmation.

    Parameters
    ----------
    image : numpy.ndarray
        Image as returned by ``cv2.imread`` (BGR channel order). It is only
        converted to RGB in a local copy for display; the caller's array is
        not modified.

    Returns
    -------
    numpy.ndarray
        Mask produced by ``RoiPoly.get_mask`` sized to ``image.shape[:2]``;
        nonzero / True inside the ROI.
    """
    display_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # matplotlib expects RGB

    # Draw on an explicitly created figure and show it non-blocking so that
    # RoiPoly attaches to *this* figure. A blocking plt.show() here waits for
    # the window to be closed; RoiPoly (with no `fig` argument) then falls
    # back to plt.gcf(), i.e. a new empty figure without the image on it.
    fig, ax = plt.subplots()
    ax.imshow(display_image)
    ax.set_title("Select ROI")
    plt.show(block=False)

    roi = RoiPoly(fig=fig, ax=ax, color='k')  # Draw ROI on the image

    # Overlay the chosen polygon on the image for visual confirmation.
    roi.display_roi()
    plt.imshow(display_image)
    plt.title("Selected ROI")
    plt.show()

    # get_mask only uses the 2-D shape of the array it is given to size the mask.
    mask = roi.get_mask(np.zeros(image.shape[:2]))

    return mask

def rgb_values(image, mask):
    """Return the pixel values of ``image`` that lie inside ``mask``.

    Parameters
    ----------
    image : array_like
        Array of shape (H, W, C) or (H, W). Images loaded with
        ``cv2.imread`` are in BGR order, and that order is preserved here:
        despite the function name, no RGB conversion is performed.
    mask : array_like
        Array of shape (H, W). Every strictly positive (or True) entry
        selects the corresponding pixel, so boolean masks and 0/255 masks
        both work.

    Returns
    -------
    numpy.ndarray
        One row per selected pixel, in row-major order: shape (K, C) for a
        (H, W, C) image, (K,) for a 2-D image. K is 0 for an empty mask.

    Raises
    ------
    ValueError
        If ``mask``'s shape differs from the first two dimensions of ``image``.
    """
    image = np.asarray(image)
    mask = np.asarray(mask)
    if mask.shape != image.shape[:2]:
        raise ValueError(
            f"mask shape {mask.shape} does not match image shape {image.shape[:2]}"
        )
    # Boolean indexing selects exactly the in-mask pixels. Pre-masking with
    # cv2.bitwise_and was redundant, and its uint8 cast zeroed pixels whose
    # mask value was fractional even though `mask > 0` still selected them.
    return image[mask > 0]

In [None]:

# Directory holding the input images 1.jpg .. NUM_IMAGES.jpg.
image_folder = 'data'
NUM_CLASSES = 4   # number of colour classes the user labels
NUM_IMAGES = 7    # images are named 1.jpg .. 7.jpg

RGB_values_all_classes = []

for class_index in range(NUM_CLASSES):

    # Labelled pixels for this class, pooled across every image.
    class_pixels = []

    for i in range(1, NUM_IMAGES + 1):
        image_path = os.path.join(image_folder, f'{i}.jpg')  # constructing file path
        image = cv2.imread(image_path)
        if image is None:  # cv2.imread returns None instead of raising on a bad path
            raise FileNotFoundError(f"Could not read image: {image_path}")

        mask = masking(image)  # interactive: user outlines this class in the image
        class_pixels.extend(rgb_values(image, mask).tolist())

    # Values are in cv2's BGR channel order; the segmentation cells also feed
    # BGR images to the model, so training and inference stay consistent.
    RGB_values_array = np.array(class_pixels)
    if RGB_values_array.shape[0] < 2:
        # np.cov needs at least two samples; fewer yields a NaN covariance.
        raise ValueError(
            f"Class {class_index}: need at least 2 labelled pixels, "
            f"got {RGB_values_array.shape[0]}"
        )

    mean = np.mean(RGB_values_array, axis=0)
    covariance = np.cov(RGB_values_array, rowvar=False)  # rows are samples

    RGB_values_all_classes.append({
        'class': class_index,
        'mean': mean,
        'covariance': covariance
    })

    print(f"Class {class_index}: Mean = {mean}, Covariance = {covariance}")


# Persist the per-class Gaussian parameters for the segmentation stage.
with open('RGB_values_all_classes_2.pkl', 'wb') as file:
    pickle.dump(RGB_values_all_classes, file)




# # Segmentation

In [None]:
# Reload the per-class Gaussian parameters written by the training cell.
# NOTE: pickle.load can execute arbitrary code -- only load a file that this
# notebook produced itself.
with open('RGB_values_all_classes_2.pkl', 'rb') as file:
    loaded_RGB_values_all_classes = pickle.load(file)

# Convert each class's mean vector and covariance matrix to plain Python
# lists. Iterating over the loaded entries works for any number of classes
# rather than assuming exactly four.
for class_params in loaded_RGB_values_all_classes:
    class_params['mean'] = np.asarray(class_params['mean']).tolist()
    class_params['covariance'] = np.asarray(class_params['covariance']).tolist()


In [None]:
from scipy.stats import multivariate_normal

def applyGaussianToImage(image, mean, covariance, allow_singular=False):
    """Evaluate a multivariate Gaussian PDF at every pixel of ``image``.

    Parameters
    ----------
    image : array_like
        Array of shape (H, W, C), or (H, W) for a single-channel image. The
        C channel values of each pixel are treated as one sample.
    mean : array_like
        Length-C mean vector of the Gaussian.
    covariance : array_like
        C x C covariance matrix of the Gaussian.
    allow_singular : bool, optional
        Forwarded to ``scipy.stats.multivariate_normal.pdf``. The default
        (False) keeps the previous behaviour of rejecting a singular
        covariance; pass True to evaluate a degenerate Gaussian, e.g. when a
        training ROI was constant in some channel.

    Returns
    -------
    numpy.ndarray
        Array of shape (H, W) holding the PDF value of each pixel.
    """
    image = np.asarray(image)

    # Evaluate in double precision.
    double_img = image.astype(np.float64)

    # One row per pixel, one column per channel (-1 lets numpy infer H*W).
    channels = 1 if image.ndim == 2 else image.shape[-1]
    image_pixels = double_img.reshape(-1, channels)

    pdf_val = multivariate_normal.pdf(
        image_pixels, mean=mean, cov=covariance, allow_singular=allow_singular
    )

    # scipy returns a flat array (or a scalar for a single pixel); restore H x W.
    pdf = np.reshape(pdf_val, image.shape[:2])

    return pdf
    

In [None]:
# Inputs are read from `image_folder`; per-class segmentation results are
# written beneath `output_folder`.
image_folder, output_folder = 'data', 'segmented_images'

# Ensure the output root exists; exist_ok makes re-running this cell a no-op.
os.makedirs(name=output_folder, exist_ok=True)

In [None]:
# A pixel is kept for a class when its PDF under that class's Gaussian
# exceeds this value.
# NOTE(review): an absolute PDF threshold depends on the covariance scale of
# each class; a Mahalanobis-distance cut-off would be scale-independent.
PDF_THRESHOLD = 1e-6
NUM_IMAGES = 7  # images are named 1.jpg .. 7.jpg

# Read every input image once, instead of once per class.
images = {}
for i in range(1, NUM_IMAGES + 1):
    image_path = os.path.join(image_folder, f'{i}.jpg')  # constructing file path
    image = cv2.imread(image_path)
    if image is None:  # cv2.imread returns None instead of raising on a bad path
        raise FileNotFoundError(f"Could not read image: {image_path}")
    images[i] = image

for class_index, class_params in enumerate(loaded_RGB_values_all_classes):
    mean = class_params['mean']
    covariance = class_params['covariance']

    # Save all segmented images of each class in a separate folder.
    class_folder = os.path.join(output_folder, f'class_{class_index}')
    os.makedirs(class_folder, exist_ok=True)

    for i, image in images.items():
        gauss_image = applyGaussianToImage(image, mean, covariance)

        # Boolean mask -> uint8, since cv2.bitwise_and needs an integer mask.
        mask = (gauss_image > PDF_THRESHOLD).astype(np.uint8)
        final_image = cv2.bitwise_and(image, image, mask=mask)

        # Same filename as the input so the grid cell can pair them up.
        output_path = os.path.join(class_folder, f'{i}.jpg')
        if not cv2.imwrite(output_path, final_image):  # imwrite reports failure via False
            raise OSError(f"Could not write image: {output_path}")

In [None]:
# preparing output image grid


image_folder = 'data'
output_folder = 'segmented_images'
n_images, n_classes = 7, 4  # grid rows; class columns after the original


def _read_rgb(path):
    """Read an image with OpenCV and return it in RGB order for matplotlib."""
    bgr = cv2.imread(path)
    if bgr is None:  # cv2.imread returns None instead of raising on a bad path
        raise FileNotFoundError(f"Could not read image: {path}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


# One row per input image: the original, then its segmentation per class.
fig, axes = plt.subplots(
    nrows=n_images, ncols=n_classes + 1, figsize=(4 * (n_classes + 1), 4 * n_images)
)

for i in range(n_images):
    filename = f'{i + 1}.jpg'

    # First column: the original image, labelled with its file number.
    axes[i, 0].imshow(_read_rgb(os.path.join(image_folder, filename)))
    axes[i, 0].set_title(f'Original Image {i + 1}')
    axes[i, 0].axis('off')

    # Remaining columns: segmented image per class. Labels use the same
    # 0-based index as the class_<k> folders and the training printout.
    for class_index in range(n_classes):
        class_folder = os.path.join(output_folder, f'class_{class_index}')
        ax = axes[i, class_index + 1]
        ax.imshow(_read_rgb(os.path.join(class_folder, filename)))
        ax.set_title(f'Class {class_index}')
        ax.axis('off')

# Adjust layout to prevent overlap
plt.tight_layout()

# Save the grid image
output_path = os.path.join(output_folder, 'segmentation_grid2.png')
plt.savefig(output_path, bbox_inches='tight', dpi=300)

# Display the grid
plt.show()
