# In [None]:
from skimage.morphology import skeletonize
from skimage.measure import LineModelND, ransac
from skimage.color import rgb2gray
from matplotlib import cm
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
from src.openImage import openImage
import os


# Remove small shapes.
def cleanNoise(image, min_size_threshold=100):
    """Return a 0/1 mask of `image` without its small connected shapes.

    Every non-zero pixel of `image` is treated as foreground and grouped
    into connected shapes by `ndimage.label`. A shape survives only when it
    contains at least `min_size_threshold` pixels.

    Parameters
    ----------
    image : array_like
        Input image; zero pixels are background.
    min_size_threshold : int, optional
        Minimum pixel count a shape needs in order to be kept.

    Returns
    -------
    numpy.ndarray
        Integer array shaped like `image`: 1 on kept shapes, 0 elsewhere.
    """
    # Give each connected shape its own integer id (0 is the background).
    component_ids, _ = ndimage.label(image)

    # Entry k holds the number of pixels carrying id k.
    pixel_counts = np.bincount(component_ids.ravel())

    # The background is not a shape: give it size 0 so it is never kept.
    pixel_counts[0] = 0

    # Look up the size of the shape each pixel belongs to and keep the big ones.
    big_enough = pixel_counts[component_ids] >= min_size_threshold

    return np.where(big_enough, 1, 0)


def _fitLines(points, min_samples=20, residual_threshold=7, max_lines=10):
    """Extract up to `max_lines` straight lines from `points` with RANSAC.

    After each successful fit the inliers of that line are removed and the
    search continues on the remaining points. The search stops when too few
    points are left, when RANSAC returns no model, or when the fit raises.

    Parameters
    ----------
    points : numpy.ndarray
        (N, 2) array of (x, y) coordinates.
    min_samples : int, optional
        Number of points RANSAC draws per trial.
    residual_threshold : float, optional
        Maximum distance for a point to count as an inlier.
    max_lines : int, optional
        Upper bound on the number of lines extracted.

    Returns
    -------
    list
        One `[slope, intercept]` pair per detected line (y = slope*x + intercept).
    """
    lines_eq = []

    for _ in range(max_lines):

        # Stop once there are not more points left than one trial samples.
        if len(points) < min_samples + 1:
            break

        try:
            # NOTE(review): newer scikit-image releases appear to have renamed
            # `random_state` to `rng` - confirm against the installed version.
            model, inliers = ransac(
                points, LineModelND, min_samples=min_samples,
                residual_threshold=residual_threshold, random_state=0)

            # No model could be fitted: nothing more to extract.
            if model is None:
                break

            # model.params is indexed as (point on the line, direction vector).
            if (model.params[1][0] == 0):
                slope = 999  # temporary fix: stand-in slope for a vertical line
            else:
                slope = model.params[1][1] / \
                    model.params[1][0]
            intercept = model.params[0][1] - \
                slope * model.params[0][0]

            # Only the points not explained by this line go to the next round.
            points = points[np.logical_not(inliers)]

            lines_eq.append([slope, intercept])
        except Exception as e:
            # Best effort: keep the lines found so far, but report the failure
            # instead of hiding it (and let KeyboardInterrupt/SystemExit through).
            print(f"Line fitting stopped early\nError code {str(e)}")
            break

    return lines_eq


def _saveLinesFigure(original_image, skeleton, lines_eq, path_and_name_output):
    """Save a 3-panel figure: original image, skeleton, skeleton + fitted lines."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 6))

    try:
        ax = axes.ravel()

        ax[0].imshow(original_image)
        ax[0].set_title('Original image')
        ax[0].set_axis_off()

        ax[1].imshow(skeleton, cmap=cm.gray)
        ax[1].set_title('Input image')
        ax[1].set_axis_off()

        ax[2].imshow(skeleton, cmap=cm.gray)
        # Keep the image orientation (row 0 at the top) after plotting lines.
        ax[2].set_ylim((skeleton.shape[0], 0))
        ax[2].set_title('Detected lines')

        # The same x range (one value per image column) serves every line.
        x = np.arange(0, skeleton.shape[1])
        for slope, intercept in lines_eq:
            ax[2].plot(x, slope * x + intercept,
                       label="fitted line", color='red')

        fig.tight_layout()
        fig.savefig(path_and_name_output)
    finally:
        # Always release the figure, even when saving fails, so that a batch
        # run does not accumulate open figures.
        plt.close(fig)


def fitImage(path_and_name_input, path_and_name_output=None):
    """Detect straight lines in an image and save a summary figure.

    The image is loaded through `openImage`, converted to gray scale,
    binarised (gray level below 0.3 becomes 0, everything else 1), cleaned
    of shapes smaller than 100 pixels and skeletonized. Lines are then
    fitted to the skeleton pixels with RANSAC. A figure is written only
    when more than one line has been found.

    Parameters
    ----------
    path_and_name_input : str
        Path of the image to analyse.
    path_and_name_output : str, optional
        Destination of the summary figure.
        NOTE(review): the default None is handed to `savefig` unchanged when
        a figure is produced - callers should always supply a path.
    """
    test_image = openImage(path_and_name_input)
    test_image.readImage()

    original_image = test_image.m_data

    # Black and white
    test_image.m_data = rgb2gray(test_image.m_data)
    test_image.m_data = np.where(test_image.m_data < 0.3, 0, 1)

    test_image.m_data = cleanNoise(test_image.m_data, 100)
    test_image.m_data = skeletonize(test_image.m_data)

    # argwhere yields (row, col); reverse the columns to get (x, y) points.
    points = np.argwhere(test_image.m_data == 1)[:, ::-1]

    lines_eq = _fitLines(points)

    if len(lines_eq) > 1:
        _saveLinesFigure(original_image, test_image.m_data,
                         lines_eq, path_and_name_output)


# Make sure the directories exist before running this part; the code does not check for them.
image_path = "/home/gant/Desktop/temp/CoBo_2018-06-20T14-17-17.236_0008/media/gant/Expansion/tpc_root_raw/DATA_ROOT/cleanimages/CoBo_2018-06-20T14-17-17.236_0008/"

# Regular files of the input directory, in alphabetical order.
image_list = sorted(
    name for name in os.listdir(image_path)
    if os.path.isfile(os.path.join(image_path, name)))

output_path = "/home/gant/Desktop/temp/found/"


# Process every image; a failure on one image is reported and the run goes on.
for image in image_list:
    try:
        fitImage(os.path.join(image_path, image),
                 os.path.join(output_path, image))
    except Exception as e:
        print(f"Error at {image}\nError code {str(e)}")
